Skip to content

Commit

Permalink
feat: Implement cache token ratio for more precise token pricing
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Mar 7, 2025
1 parent 81137e0 commit 4f194f4
Show file tree
Hide file tree
Showing 18 changed files with 258 additions and 71 deletions.
2 changes: 1 addition & 1 deletion controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice)
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, 0, 0.0, priceData.ModelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
Expand Down
5 changes: 3 additions & 2 deletions controller/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin"
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
)

func GetPricing(c *gin.Context) {
Expand Down Expand Up @@ -39,7 +40,7 @@ func GetPricing(c *gin.Context) {
}

func ResetModelRatio(c *gin.Context) {
defaultStr := setting.DefaultModelRatio2JSONString()
defaultStr := operation_setting.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
Expand All @@ -48,7 +49,7 @@ func ResetModelRatio(c *gin.Context) {
})
return
}
err = setting.UpdateModelRatioByJSONString(defaultStr)
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
Expand Down
15 changes: 9 additions & 6 deletions model/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = setting.ModelPrice2JSONString()
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = setting.CompletionRatio2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["ChatLink2"] = common.ChatLink2
Expand Down Expand Up @@ -344,15 +345,17 @@ func updateOptionMap(key string, value string) (err error) {
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = setting.UpdateModelRatioByJSONString(value)
err = operation_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = setting.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = setting.UpdateCompletionRatioByJSONString(value)
err = operation_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = setting.UpdateModelPriceByJSONString(value)
err = operation_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = operation_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
case "ChatLink":
Expand Down
8 changes: 4 additions & 4 deletions model/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package model

import (
"one-api/common"
"one-api/setting"
"one-api/setting/operation_setting"
"sync"
"time"
)
Expand Down Expand Up @@ -65,14 +65,14 @@ func updatePricing() {
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := setting.GetModelPrice(model, false)
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
modelRatio, _ := setting.GetModelRatio(model)
modelRatio, _ := operation_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
pricing.CompletionRatio = setting.GetCompletionRatio(model)
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)
Expand Down
11 changes: 8 additions & 3 deletions relay/helper/price.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,42 @@ import (
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
"one-api/setting/operation_setting"
)

// PriceData bundles the pricing inputs that ModelPriceHelper resolves for a
// single relayed request, so later billing code can reuse them instead of
// looking the settings up again.
type PriceData struct {
// ModelPrice is the per-request price looked up via GetModelPrice.
// NOTE(review): presumably only meaningful when UsePrice is true — confirm.
ModelPrice float64
// ModelRatio is the per-model multiplier looked up via GetModelRatio; it is
// left at zero when UsePrice is true.
ModelRatio float64
// CompletionRatio is the multiplier looked up via GetCompletionRatio; it is
// left at zero when UsePrice is true.
CompletionRatio float64
// CacheRatio is the multiplier looked up via GetCacheRatio; it is left at
// zero when UsePrice is true.
CacheRatio float64
// GroupRatio is the multiplier for the requesting user's group, looked up
// via GetGroupRatio in both pricing modes.
GroupRatio float64
// UsePrice reports whether GetModelPrice found a price for the model; when
// false, the ratio fields above are populated instead.
UsePrice bool
// ShouldPreConsumedQuota is the quota to reserve before the request is sent.
// In ratio mode it is the pre-consumed token count times ModelRatio*GroupRatio.
ShouldPreConsumedQuota int
}

func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := setting.GetModelPrice(info.OriginModelName, false)
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
var cacheRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 {
preConsumedTokens = promptTokens + maxTokens
}
var success bool
modelRatio, success = setting.GetModelRatio(info.OriginModelName)
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
if !success {
if info.UserId == 1 {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
} else {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
}
}
completionRatio = setting.GetCompletionRatio(info.OriginModelName)
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
Expand All @@ -49,6 +53,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
CompletionRatio: completionRatio,
GroupRatio: groupRatio,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ShouldPreConsumedQuota: preConsumedQuota,
}, nil
}
9 changes: 5 additions & 4 deletions relay/relay-mj.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -157,10 +158,10 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice, success := setting.GetModelPrice(modelName, true)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
Expand Down Expand Up @@ -463,10 +464,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice, success := setting.GetModelPrice(modelName, true)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
Expand Down
17 changes: 10 additions & 7 deletions relay/relay-text.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}

// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
Expand Down Expand Up @@ -304,24 +304,26 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
CompletionTokens: 0,
TotalTokens: relayInfo.PromptTokens,
}
extraContent += "(可能是请求出错)"
extraContent += "(可能是请求出错)"
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
cacheTokens := usage.PromptTokensDetails.CachedTokens
completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName

tokenName := ctx.GetString("token_name")
completionRatio := setting.GetCompletionRatio(modelName)
completionRatio := priceData.CompletionRatio
cacheRatio := priceData.CacheRatio
ratio := priceData.ModelRatio * priceData.GroupRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice

quota := 0
if !priceData.UsePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = (promptTokens - cacheTokens) + int(math.Round(float64(cacheTokens)*cacheRatio))
quota += int(math.Round(float64(completionTokens) * completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
Expand All @@ -330,8 +332,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := promptTokens + completionTokens

var logContent string
if !usePrice {
if !priceData.UsePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
Expand Down Expand Up @@ -372,7 +375,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)

Expand Down
5 changes: 3 additions & 2 deletions relay/relay_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
)

/*
Expand All @@ -37,9 +38,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}

modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
modelPrice, success := setting.GetModelPrice(modelName, true)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
if !success {
defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
Expand Down
5 changes: 3 additions & 2 deletions relay/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
)

func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
Expand Down Expand Up @@ -39,7 +40,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
}
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := setting.GetModelPrice(relayInfo.UpstreamModelName, false)
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)

var preConsumedQuota int
Expand All @@ -65,7 +66,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
//}
modelRatio, _ = setting.GetModelRatio(relayInfo.UpstreamModelName)
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
Expand Down
12 changes: 8 additions & 4 deletions service/log_info_generate.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package service

import (
"github.com/gin-gonic/gin"
"one-api/dto"
relaycommon "one-api/relay/common"

"github.com/gin-gonic/gin"
)

func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
cacheTokens int, cacheRatio float64, modelPrice float64) map[string]interface{} {
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
other["completion_ratio"] = completionRatio
other["cache_tokens"] = cacheTokens
other["cache_ratio"] = cacheRatio
other["model_price"] = modelPrice
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
if relayInfo.ReasoningEffort != "" {
Expand All @@ -27,7 +31,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
}

func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
info["ws"] = true
info["audio_input"] = usage.InputTokenDetails.AudioTokens
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
Expand All @@ -39,7 +43,7 @@ func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us
}

func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
info["audio"] = true
info["audio_input"] = usage.PromptTokensDetails.AudioTokens
info["audio_output"] = usage.CompletionTokenDetails.AudioTokens
Expand Down
21 changes: 11 additions & 10 deletions service/quota.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting"
"one-api/setting/operation_setting"
"strings"
"time"

Expand All @@ -38,9 +39,9 @@ func calculateAudioQuota(info QuotaInfo) int {
return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
}

completionRatio := setting.GetCompletionRatio(info.ModelName)
audioRatio := setting.GetAudioRatio(info.ModelName)
audioCompletionRatio := setting.GetAudioCompletionRatio(info.ModelName)
completionRatio := operation_setting.GetCompletionRatio(info.ModelName)
audioRatio := operation_setting.GetAudioRatio(info.ModelName)
audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName)
ratio := info.GroupRatio * info.ModelRatio

quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
Expand Down Expand Up @@ -75,7 +76,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
groupRatio := setting.GetGroupRatio(relayInfo.Group)
modelRatio, _ := setting.GetModelRatio(modelName)
modelRatio, _ := operation_setting.GetModelRatio(modelName)

quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
Expand Down Expand Up @@ -122,9 +123,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioOutTokens := usage.OutputTokenDetails.AudioTokens

tokenName := ctx.GetString("token_name")
completionRatio := setting.GetCompletionRatio(modelName)
audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := setting.GetAudioCompletionRatio(modelName)
completionRatio := operation_setting.GetCompletionRatio(modelName)
audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := operation_setting.GetAudioCompletionRatio(modelName)

quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
Expand Down Expand Up @@ -184,9 +185,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens

tokenName := ctx.GetString("token_name")
completionRatio := setting.GetCompletionRatio(relayInfo.OriginModelName)
audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := setting.GetAudioCompletionRatio(relayInfo.OriginModelName)
completionRatio := operation_setting.GetCompletionRatio(relayInfo.OriginModelName)
audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)

modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
Expand Down
Loading

0 comments on commit 4f194f4

Please sign in to comment.