diff --git a/service/aiproxy/common/fastJSONSerializer/fastJSONSerializer.go b/service/aiproxy/common/fastJSONSerializer/fastJSONSerializer.go index 32fe8e48358..342585626b4 100644 --- a/service/aiproxy/common/fastJSONSerializer/fastJSONSerializer.go +++ b/service/aiproxy/common/fastJSONSerializer/fastJSONSerializer.go @@ -26,6 +26,11 @@ func (*JSONSerializer) Scan(ctx context.Context, field *schema.Field, dst reflec return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) } + if len(bytes) == 0 { + field.ReflectValueOf(ctx, dst).Set(reflect.Zero(field.FieldType)) + return nil + } + err = json.Unmarshal(bytes, fieldValue.Interface()) } diff --git a/service/aiproxy/controller/channel.go b/service/aiproxy/controller/channel.go index c8f194716ac..e9308658d63 100644 --- a/service/aiproxy/controller/channel.go +++ b/service/aiproxy/controller/channel.go @@ -26,24 +26,14 @@ func ChannelTypeMetas(c *gin.Context) { } func GetChannels(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } + page, perPage := parsePageParams(c) id, _ := strconv.Atoi(c.Query("id")) name := c.Query("name") key := c.Query("key") channelType, _ := strconv.Atoi(c.Query("channel_type")) baseURL := c.Query("base_url") order := c.Query("order") - channels, total, err := model.GetChannels(p*perPage, perPage, id, name, key, channelType, baseURL, order) + channels, total, err := model.GetChannels(page*perPage, perPage, id, name, key, channelType, baseURL, order) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return @@ -89,24 +79,14 @@ func AddChannels(c *gin.Context) { func SearchChannels(c *gin.Context) { keyword := c.Query("keyword") - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } + page, perPage := parsePageParams(c) id, _ := strconv.Atoi(c.Query("id")) name := c.Query("name") key := c.Query("key") channelType, _ := strconv.Atoi(c.Query("channel_type")) baseURL := c.Query("base_url") order := c.Query("order") - channels, total, err := model.SearchChannels(keyword, p*perPage, perPage, id, name, key, channelType, baseURL, order) + channels, total, err := model.SearchChannels(keyword, page*perPage, perPage, id, name, key, channelType, baseURL, order) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return diff --git a/service/aiproxy/controller/dashboard.go b/service/aiproxy/controller/dashboard.go index 54c95ad7914..10c5573549a 100644 --- a/service/aiproxy/controller/dashboard.go +++ b/service/aiproxy/controller/dashboard.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/http" - "strconv" "time" "github.com/gin-gonic/gin" @@ -14,34 +13,42 @@ import ( "github.com/labring/sealos/service/aiproxy/model" ) -func getDashboardTime(t string) (time.Time, time.Time, time.Duration) { +func getDashboardTime(t string) (time.Time, time.Time, model.TimeSpanType) { end := time.Now() var start time.Time - var timeSpan time.Duration + var timeSpan model.TimeSpanType switch t { case "month": start = end.AddDate(0, 0, -30) - timeSpan = time.Hour * 24 + timeSpan = model.TimeSpanDay case "two_week": start = end.AddDate(0, 0, -15) - timeSpan = time.Hour * 24 + timeSpan = model.TimeSpanDay case "week": start = end.AddDate(0, 0, -7) - timeSpan = time.Hour * 24 + timeSpan = model.TimeSpanDay case "day": fallthrough default: start = end.AddDate(0, 0, -1) - timeSpan = time.Hour * 1 + timeSpan = model.TimeSpanHour } return start, end, timeSpan } -func fillGaps(data []*model.ChartData, start, end time.Time, timeSpan time.Duration) []*model.ChartData { +func fillGaps(data []*model.ChartData, start, end time.Time, t model.TimeSpanType) []*model.ChartData { if len(data) == 0 { return data } + var timeSpan time.Duration + switch t { + case model.TimeSpanDay: + timeSpan = time.Hour * 24 + default: + timeSpan = time.Hour + } + // Handle first point firstPoint := time.Unix(data[0].Timestamp, 0) firstAlignedTime := firstPoint @@ -116,27 +123,11 @@ func fillGaps(data []*model.ChartData, start, end time.Time, timeSpan time.Durat return result } -func getTimeSpanWithDefault(c *gin.Context, defaultTimeSpan time.Duration) time.Duration { - spanStr := c.Query("span") - if spanStr == "" { - return defaultTimeSpan - } - span, err := strconv.Atoi(spanStr) - if err != nil { - return defaultTimeSpan - } - if span < 1 || span > 48 { - return defaultTimeSpan - } - return time.Duration(span) * time.Hour -} - func GetDashboard(c *gin.Context) { log := middleware.GetLogger(c) start, end, timeSpan := getDashboardTime(c.Query("type")) modelName := c.Query("model") - timeSpan = getTimeSpanWithDefault(c, timeSpan) dashboards, err := model.GetDashboardData(start, end, modelName, timeSpan) if err != nil { @@ -170,7 +161,6 @@ func GetGroupDashboard(c *gin.Context) { start, end, timeSpan := getDashboardTime(c.Query("type")) tokenName := c.Query("token_name") modelName := c.Query("model") - timeSpan = getTimeSpanWithDefault(c, timeSpan) dashboards, err := model.GetGroupDashboardData(group, start, end, tokenName, modelName, timeSpan) if err != nil { diff --git a/service/aiproxy/controller/group.go b/service/aiproxy/controller/group.go index dd5580ae6a0..bc0d0f16ed8 100644 --- a/service/aiproxy/controller/group.go +++ b/service/aiproxy/controller/group.go @@ -30,20 +30,9 @@ func (g *GroupResponse) MarshalJSON() ([]byte, error) { } func GetGroups(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } - + page, perPage := parsePageParams(c) order := c.DefaultQuery("order", "") - groups, total, err := model.GetGroups(p*perPage, perPage, order, false) + groups, total, err := model.GetGroups(page*perPage, perPage, order, false) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return @@ -64,20 +53,10 @@ func GetGroups(c *gin.Context) { func SearchGroups(c *gin.Context) { keyword := c.Query("keyword") - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } + page, perPage := parsePageParams(c) order := c.DefaultQuery("order", "") status, _ := strconv.Atoi(c.Query("status")) - groups, total, err := model.SearchGroup(keyword, p*perPage, perPage, order, status) + groups, total, err := model.SearchGroup(keyword, page*perPage, perPage, order, status) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return diff --git a/service/aiproxy/controller/import.go b/service/aiproxy/controller/import.go new file mode 100644 index 00000000000..a7c70870f11 --- /dev/null +++ b/service/aiproxy/controller/import.go @@ -0,0 +1,209 @@ +package controller + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/labring/sealos/service/aiproxy/middleware" + "github.com/labring/sealos/service/aiproxy/model" + "gorm.io/gorm" +) + +type OneAPIChannel struct { + Type int `json:"type" gorm:"default:0"` + Key string `json:"key" gorm:"type:text"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index"` + BaseURL string `gorm:"column:base_url;default:''"` + Models string `json:"models"` + ModelMapping map[string]string `gorm:"type:varchar(1024);serializer:fastjson"` + Priority int32 `gorm:"bigint;default:0"` + Config ChannelConfig `gorm:"serializer:fastjson"` +} + +func (c *OneAPIChannel) TableName() string { + return "channels" +} + +type ChannelConfig struct { + Region string `json:"region,omitempty"` + SK string `json:"sk,omitempty"` + AK string `json:"ak,omitempty"` + UserID string `json:"user_id,omitempty"` + APIVersion string `json:"api_version,omitempty"` + LibraryID string `json:"library_id,omitempty"` + VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"` + VertexAIADC string `json:"vertex_ai_adc,omitempty"` +} + +// https://github.com/songquanpeng/one-api/blob/main/relay/channeltype/define.go +const ( + OneAPIOpenAI = iota + 1 + OneAPIAPI2D + OneAPIAzure + OneAPICloseAI + OneAPIOpenAISB + OneAPIOpenAIMax + OneAPIOhMyGPT + OneAPICustom + OneAPIAils + OneAPIAIProxy + OneAPIPaLM + OneAPIAPI2GPT + OneAPIAIGC2D + OneAPIAnthropic + OneAPIBaidu + OneAPIZhipu + OneAPIAli + OneAPIXunfei + OneAPIAI360 + OneAPIOpenRouter + OneAPIAIProxyLibrary + OneAPIFastGPT + OneAPITencent + OneAPIGemini + OneAPIMoonshot + OneAPIBaichuan + OneAPIMinimax + OneAPIMistral + OneAPIGroq + OneAPIOllama + OneAPILingYiWanWu + OneAPIStepFun + OneAPIAwsClaude + OneAPICoze + OneAPICohere + OneAPIDeepSeek + OneAPICloudflare + OneAPIDeepL + OneAPITogetherAI + OneAPIDoubao + OneAPINovita + OneAPIVertextAI + OneAPIProxy + OneAPISiliconFlow + OneAPIXAI + OneAPIReplicate + OneAPIBaiduV2 + OneAPIXunfeiV2 + OneAPIAliBailian + OneAPIOpenAICompatible + OneAPIGeminiOpenAICompatible +) + +// relay/channeltype/define.go + +var OneAPIChannelType2AIProxyMap = map[int]int{ + OneAPIOpenAI: 1, + OneAPIAzure: 3, + OneAPIAnthropic: 14, + OneAPIBaidu: 15, + OneAPIZhipu: 16, + OneAPIAli: 17, + OneAPIAI360: 19, + OneAPITencent: 23, + OneAPIGemini: 24, + OneAPIMoonshot: 25, + OneAPIBaichuan: 26, + OneAPIMinimax: 27, + OneAPIMistral: 28, + OneAPIGroq: 29, + OneAPIOllama: 30, + OneAPILingYiWanWu: 31, + OneAPIStepFun: 32, + OneAPIAwsClaude: 33, + OneAPICoze: 34, + OneAPICohere: 35, + OneAPIDeepSeek: 36, + OneAPICloudflare: 37, + OneAPIDoubao: 40, + OneAPINovita: 41, + OneAPIVertextAI: 42, + OneAPISiliconFlow: 43, + OneAPIBaiduV2: 13, + OneAPIXunfeiV2: 18, + OneAPIAliBailian: 17, + OneAPIGeminiOpenAICompatible: 12, +} + +type ImportChannelFromOneAPIRequest struct { + DSN string `json:"dsn"` +} + +func AddOneAPIChannel(ch OneAPIChannel) error { + add := AddChannelRequest{ + Type: ch.Type, + Name: ch.Name, + Key: ch.Key, + BaseURL: ch.BaseURL, + Models: strings.Split(ch.Models, ","), + ModelMapping: ch.ModelMapping, + Priority: ch.Priority, + Status: ch.Status, + } + if t, ok := OneAPIChannelType2AIProxyMap[ch.Type]; ok { + add.Type = t + } else { + add.Type = 1 + } + if add.Type == 1 && add.BaseURL != "" { + add.BaseURL += "/v1" + } + chs, err := add.ToChannels() + if err != nil { + return err + } + return model.BatchInsertChannels(chs) +} + +func ImportChannelFromOneAPI(c *gin.Context) { + var req ImportChannelFromOneAPIRequest + if err := c.ShouldBindJSON(&req); err != nil { + middleware.ErrorResponse(c, http.StatusBadRequest, err.Error()) + return + } + + if req.DSN == "" { + middleware.ErrorResponse(c, http.StatusBadRequest, "sql dsn is required") + return + } + + var db *gorm.DB + var err error + if strings.HasPrefix(req.DSN, "mysql") { + db, err = model.OpenMySQL(req.DSN) + } else if strings.HasPrefix(req.DSN, "postgres") { + db, err = model.OpenPostgreSQL(req.DSN) + } else { + middleware.ErrorResponse(c, http.StatusBadRequest, "invalid dsn, only mysql and postgres are supported") + return + } + if err != nil { + middleware.ErrorResponse(c, http.StatusBadRequest, err.Error()) + return + } + sqlDB, err := db.DB() + if err != nil { + middleware.ErrorResponse(c, http.StatusBadRequest, err.Error()) + return + } + defer sqlDB.Close() + + allChannels := make([]*OneAPIChannel, 0) + err = db.Model(&OneAPIChannel{}).Find(&allChannels).Error + if err != nil { + middleware.ErrorResponse(c, http.StatusBadRequest, err.Error()) + return + } + + errs := make([]error, 0) + for _, ch := range allChannels { + err := AddOneAPIChannel(*ch) + if err != nil { + errs = append(errs, err) + } + } + + middleware.SuccessResponse(c, errs) +} diff --git a/service/aiproxy/controller/log.go b/service/aiproxy/controller/log.go index f028a9099be..47f53484fc6 100644 --- a/service/aiproxy/controller/log.go +++ b/service/aiproxy/controller/log.go @@ -10,56 +10,75 @@ import ( "github.com/labring/sealos/service/aiproxy/model" ) -func GetLogs(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } +func parseTimeRange(c *gin.Context) (startTime, endTime time.Time) { startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) - var startTimestampTime time.Time + if startTimestamp != 0 { - startTimestampTime = time.UnixMilli(startTimestamp) + startTime = time.UnixMilli(startTimestamp) } - var endTimestampTime time.Time + sevenDaysAgo := time.Now().AddDate(0, 0, -7) + if startTime.IsZero() || startTime.Before(sevenDaysAgo) { + startTime = sevenDaysAgo + } + if endTimestamp != 0 { - endTimestampTime = time.UnixMilli(endTimestamp) + endTime = time.UnixMilli(endTimestamp) } - tokenName := c.Query("token_name") - modelName := c.Query("model_name") - channelID, _ := strconv.Atoi(c.Query("channel")) + return +} + +func parseCommonParams(c *gin.Context) (params struct { + tokenName string + modelName string + channelID int + endpoint string + tokenID int + order string + requestID string + mode int + codeType string + withBody bool + ip string +}, +) { + params.tokenName = c.Query("token_name") + params.modelName = c.Query("model_name") + params.channelID, _ = strconv.Atoi(c.Query("channel")) + params.endpoint = c.Query("endpoint") + params.tokenID, _ = strconv.Atoi(c.Query("token_id")) + params.order = c.Query("order") + params.requestID = c.Query("request_id") + params.mode, _ = strconv.Atoi(c.Query("mode")) + params.codeType = c.Query("code_type") + params.withBody, _ = strconv.ParseBool(c.Query("with_body")) + params.ip = c.Query("ip") + return +} + +// Handler functions +func GetLogs(c *gin.Context) { + page, perPage := parsePageParams(c) + startTime, endTime := parseTimeRange(c) + params := parseCommonParams(c) group := c.Query("group") - endpoint := c.Query("endpoint") - tokenID, _ := strconv.Atoi(c.Query("token_id")) - order := c.Query("order") - requestID := c.Query("request_id") - mode, _ := strconv.Atoi(c.Query("mode")) - codeType := c.Query("code_type") - withBody, _ := strconv.ParseBool(c.Query("with_body")) - ip := c.Query("ip") + result, err := model.GetLogs( group, - startTimestampTime, - endTimestampTime, - modelName, - requestID, - tokenID, - tokenName, - channelID, - endpoint, - order, - mode, - model.CodeType(codeType), - withBody, - ip, - p, + startTime, + endTime, + params.modelName, + params.requestID, + params.tokenID, + params.tokenName, + params.channelID, + params.endpoint, + params.order, + params.mode, + model.CodeType(params.codeType), + params.withBody, + params.ip, + page, perPage, ) if err != nil { @@ -75,54 +94,27 @@ func GetGroupLogs(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, "group is required") return } - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } - startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) - endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) - var startTimestampTime time.Time - if startTimestamp != 0 { - startTimestampTime = time.UnixMilli(startTimestamp) - } - var endTimestampTime time.Time - if endTimestamp != 0 { - endTimestampTime = time.UnixMilli(endTimestamp) - } - tokenName := c.Query("token_name") - modelName := c.Query("model_name") - channelID, _ := strconv.Atoi(c.Query("channel")) - endpoint := c.Query("endpoint") - tokenID, _ := strconv.Atoi(c.Query("token_id")) - order := c.Query("order") - requestID := c.Query("request_id") - mode, _ := strconv.Atoi(c.Query("mode")) - codeType := c.Query("code_type") - withBody, _ := strconv.ParseBool(c.Query("with_body")) - ip := c.Query("ip") + + page, perPage := parsePageParams(c) + startTime, endTime := parseTimeRange(c) + params := parseCommonParams(c) + result, err := model.GetGroupLogs( group, - startTimestampTime, - endTimestampTime, - modelName, - requestID, - tokenID, - tokenName, - channelID, - endpoint, - order, - mode, - model.CodeType(codeType), - withBody, - ip, - p, + startTime, + endTime, + params.modelName, + params.requestID, + params.tokenID, + params.tokenName, + params.channelID, + params.endpoint, + params.order, + params.mode, + model.CodeType(params.codeType), + params.withBody, + params.ip, + page, perPage, ) if err != nil { @@ -133,53 +125,30 @@ func GetGroupLogs(c *gin.Context) { } func SearchLogs(c *gin.Context) { + page, perPage := parsePageParams(c) + startTime, endTime := parseTimeRange(c) + params := parseCommonParams(c) + keyword := c.Query("keyword") - p, _ := strconv.Atoi(c.Query("p")) - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } - endpoint := c.Query("endpoint") - tokenName := c.Query("token_name") - modelName := c.Query("model_name") group := c.Query("group_id") - tokenID, _ := strconv.Atoi(c.Query("token_id")) - channelID, _ := strconv.Atoi(c.Query("channel")) - startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) - endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) - var startTimestampTime time.Time - if startTimestamp != 0 { - startTimestampTime = time.UnixMilli(startTimestamp) - } - var endTimestampTime time.Time - if endTimestamp != 0 { - endTimestampTime = time.UnixMilli(endTimestamp) - } - order := c.Query("order") - requestID := c.Query("request_id") - mode, _ := strconv.Atoi(c.Query("mode")) - codeType := c.Query("code_type") - withBody, _ := strconv.ParseBool(c.Query("with_body")) - ip := c.Query("ip") + result, err := model.SearchLogs( group, keyword, - endpoint, - requestID, - tokenID, - tokenName, - modelName, - startTimestampTime, - endTimestampTime, - channelID, - order, - mode, - model.CodeType(codeType), - withBody, - ip, - p, + params.endpoint, + params.requestID, + params.tokenID, + params.tokenName, + params.modelName, + startTime, + endTime, + params.channelID, + params.order, + params.mode, + model.CodeType(params.codeType), + params.withBody, + params.ip, + page, perPage, ) if err != nil { @@ -195,52 +164,29 @@ func SearchGroupLogs(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, "group is required") return } + + page, perPage := parsePageParams(c) + startTime, endTime := parseTimeRange(c) + params := parseCommonParams(c) keyword := c.Query("keyword") - p, _ := strconv.Atoi(c.Query("p")) - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } - endpoint := c.Query("endpoint") - tokenName := c.Query("token_name") - modelName := c.Query("model_name") - tokenID, _ := strconv.Atoi(c.Query("token_id")) - channelID, _ := strconv.Atoi(c.Query("channel")) - startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) - endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) - var startTimestampTime time.Time - if startTimestamp != 0 { - startTimestampTime = time.UnixMilli(startTimestamp) - } - var endTimestampTime time.Time - if endTimestamp != 0 { - endTimestampTime = time.UnixMilli(endTimestamp) - } - order := c.Query("order") - requestID := c.Query("request_id") - mode, _ := strconv.Atoi(c.Query("mode")) - codeType := c.Query("code_type") - withBody, _ := strconv.ParseBool(c.Query("with_body")) - ip := c.Query("ip") + result, err := model.SearchGroupLogs( group, keyword, - endpoint, - requestID, - tokenID, - tokenName, - modelName, - startTimestampTime, - endTimestampTime, - channelID, - order, - mode, - model.CodeType(codeType), - withBody, - ip, - p, + params.endpoint, + params.requestID, + params.tokenID, + params.tokenName, + params.modelName, + startTime, + endTime, + params.channelID, + params.order, + params.mode, + model.CodeType(params.codeType), + params.withBody, + params.ip, + page, perPage, ) if err != nil { @@ -297,6 +243,7 @@ func SearchConsumeError(c *gin.Context) { content := c.Query("content") tokenID, _ := strconv.Atoi(c.Query("token_id")) usedAmount, _ := strconv.ParseFloat(c.Query("used_amount"), 64) + page, _ := strconv.Atoi(c.Query("page")) perPage, _ := strconv.Atoi(c.Query("per_page")) if perPage <= 0 { @@ -304,9 +251,23 @@ func SearchConsumeError(c *gin.Context) { } else if perPage > 100 { perPage = 100 } + order := c.Query("order") requestID := c.Query("request_id") - logs, total, err := model.SearchConsumeError(keyword, requestID, group, tokenName, modelName, content, usedAmount, tokenID, page, perPage, order) + + logs, total, err := model.SearchConsumeError( + keyword, + requestID, + group, + tokenName, + modelName, + content, + usedAmount, + tokenID, + page, + perPage, + order, + ) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return diff --git a/service/aiproxy/controller/modelconfig.go b/service/aiproxy/controller/modelconfig.go index b79d130a70a..9925800d674 100644 --- a/service/aiproxy/controller/modelconfig.go +++ b/service/aiproxy/controller/modelconfig.go @@ -2,7 +2,6 @@ package controller import ( "net/http" - "strconv" "github.com/gin-gonic/gin" "github.com/labring/sealos/service/aiproxy/middleware" @@ -10,19 +9,9 @@ import ( ) func GetModelConfigs(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } + page, perPage := parsePageParams(c) _model := c.Query("model") - configs, total, err := model.GetModelConfigs(p*perPage, perPage, _model) + configs, total, err := model.GetModelConfigs(page*perPage, perPage, _model) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return @@ -63,20 +52,10 @@ func GetModelConfigsByModelsContains(c *gin.Context) { func SearchModelConfigs(c *gin.Context) { keyword := c.Query("keyword") - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } + page, perPage := parsePageParams(c) _model := c.Query("model") owner := c.Query("owner") - configs, total, err := model.SearchModelConfigs(keyword, p*perPage, perPage, _model, model.ModelOwner(owner)) + configs, total, err := model.SearchModelConfigs(keyword, page*perPage, perPage, _model, model.ModelOwner(owner)) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return diff --git a/service/aiproxy/controller/monitor.go b/service/aiproxy/controller/monitor.go index fbe4be466d0..5e555466839 100644 --- a/service/aiproxy/controller/monitor.go +++ b/service/aiproxy/controller/monitor.go @@ -72,3 +72,12 @@ func ClearChannelModelErrors(c *gin.Context) { } c.Status(http.StatusNoContent) } + +func GetModelsErrorRate(c *gin.Context) { + rates, err := monitor.GetModelsErrorRate(c.Request.Context()) + if err != nil { + middleware.ErrorResponse(c, http.StatusOK, err.Error()) + return + } + c.JSON(http.StatusOK, rates) +} diff --git a/service/aiproxy/controller/token.go b/service/aiproxy/controller/token.go index abdf52815aa..19aaf4b15f3 100644 --- a/service/aiproxy/controller/token.go +++ b/service/aiproxy/controller/token.go @@ -15,6 +15,7 @@ import ( "github.com/labring/sealos/service/aiproxy/model" ) +// TokenResponse represents the response structure for token endpoints type TokenResponse struct { *model.Token AccessedAt time.Time `json:"accessed_at"` @@ -35,36 +36,80 @@ func (t *TokenResponse) MarshalJSON() ([]byte, error) { }) } -func GetTokens(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 +type ( + AddTokenRequest struct { + Name string `json:"name"` + Subnet string `json:"subnet"` + Models []string `json:"models"` + ExpiredAt int64 `json:"expiredAt"` + Quota float64 `json:"quota"` + } + + UpdateTokenStatusRequest struct { + Status int `json:"status"` + } + + UpdateTokenNameRequest struct { + Name string `json:"name"` + } +) + +func validateToken(token AddTokenRequest) error { + if token.Name == "" { + return errors.New("token name cannot be empty") + } + if len(token.Name) > 30 { + return errors.New("token name is too long") + } + if token.Subnet != "" { + if err := network.IsValidSubnets(token.Subnet); err != nil { + return fmt.Errorf("invalid subnet: %w", err) + } + } + return nil +} + +func validateTokenStatus(token *model.Token) error { + if token.Status == model.TokenStatusExpired && !token.ExpiredAt.IsZero() && token.ExpiredAt.Before(time.Now()) { + return errors.New("token expired, please update token expired time or set to never expire") + } + if token.Status == model.TokenStatusExhausted && token.Quota > 0 && token.UsedAmount >= token.Quota { + return errors.New("token quota exhausted, please update token quota or set to unlimited quota") + } + return nil +} + +func buildTokenResponse(token *model.Token) *TokenResponse { + lastRequestAt, _ := model.GetTokenLastRequestTime(token.ID) + return &TokenResponse{ + Token: token, + AccessedAt: lastRequestAt, } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 +} + +func buildTokenResponses(tokens []*model.Token) []*TokenResponse { + responses := make([]*TokenResponse, len(tokens)) + for i, token := range tokens { + responses[i] = buildTokenResponse(token) } + return responses +} + +// Token list handlers +func GetTokens(c *gin.Context) { + page, perPage := parsePageParams(c) group := c.Query("group") order := c.Query("order") status, _ := strconv.Atoi(c.Query("status")) - tokens, total, err := model.GetTokens(group, p*perPage, perPage, order, status) + + tokens, total, err := model.GetTokens(group, page*perPage, perPage, order, status) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - tokenResponses := make([]*TokenResponse, len(tokens)) - for i, token := range tokens { - lastRequestAt, _ := model.GetTokenLastRequestTime(token.ID) - tokenResponses[i] = &TokenResponse{ - Token: token, - AccessedAt: lastRequestAt, - } - } + middleware.SuccessResponse(c, gin.H{ - "tokens": tokenResponses, + "tokens": buildTokenResponses(tokens), "total": total, }) } @@ -75,71 +120,40 @@ func GetGroupTokens(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, "group is required") return } - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } + + page, perPage := parsePageParams(c) order := c.Query("order") status, _ := strconv.Atoi(c.Query("status")) - tokens, total, err := model.GetTokens(group, p*perPage, perPage, order, status) + + tokens, total, err := model.GetTokens(group, page*perPage, perPage, order, status) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - tokenResponses := make([]*TokenResponse, len(tokens)) - for i, token := range tokens { - lastRequestAt, _ := model.GetTokenLastRequestTime(token.ID) - tokenResponses[i] = &TokenResponse{ - Token: token, - AccessedAt: lastRequestAt, - } - } + middleware.SuccessResponse(c, gin.H{ - "tokens": tokenResponses, + "tokens": buildTokenResponses(tokens), "total": total, }) } func SearchTokens(c *gin.Context) { + page, perPage := parsePageParams(c) keyword := c.Query("keyword") - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } order := c.Query("order") name := c.Query("name") key := c.Query("key") status, _ := strconv.Atoi(c.Query("status")) group := c.Query("group") - tokens, total, err := model.SearchTokens(group, keyword, p*perPage, perPage, order, status, name, key) + + tokens, total, err := model.SearchTokens(group, keyword, page*perPage, perPage, order, status, name, key) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - tokenResponses := make([]*TokenResponse, len(tokens)) - for i, token := range tokens { - lastRequestAt, _ := model.GetTokenLastRequestTime(token.ID) - tokenResponses[i] = &TokenResponse{ - Token: token, - AccessedAt: lastRequestAt, - } - } + middleware.SuccessResponse(c, gin.H{ - "tokens": tokenResponses, + "tokens": buildTokenResponses(tokens), "total": total, }) } @@ -150,58 +164,41 @@ func SearchGroupTokens(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, "group is required") return } + + page, perPage := parsePageParams(c) keyword := c.Query("keyword") - p, _ := strconv.Atoi(c.Query("p")) - p-- - if p < 0 { - p = 0 - } - perPage, _ := strconv.Atoi(c.Query("per_page")) - if perPage <= 0 { - perPage = 10 - } else if perPage > 100 { - perPage = 100 - } order := c.Query("order") name := c.Query("name") key := c.Query("key") status, _ := strconv.Atoi(c.Query("status")) - tokens, total, err := model.SearchTokens(group, keyword, p*perPage, perPage, order, status, name, key) + + tokens, total, err := model.SearchTokens(group, keyword, page*perPage, perPage, order, status, name, key) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - tokenResponses := make([]*TokenResponse, len(tokens)) - for i, token := range tokens { - lastRequestAt, _ := model.GetTokenLastRequestTime(token.ID) - tokenResponses[i] = &TokenResponse{ - Token: token, - AccessedAt: lastRequestAt, - } - } + middleware.SuccessResponse(c, gin.H{ - "tokens": tokenResponses, + "tokens": buildTokenResponses(tokens), "total": total, }) } +// Single token handlers func GetToken(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } + token, err := model.GetTokenByID(id) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - lastRequestAt, _ := model.GetTokenLastRequestTime(id) - tokenResponse := &TokenResponse{ - Token: token, - AccessedAt: lastRequestAt, - } - middleware.SuccessResponse(c, tokenResponse) + + middleware.SuccessResponse(c, buildTokenResponse(token)) } func GetGroupToken(c *gin.Context) { @@ -210,114 +207,86 @@ func GetGroupToken(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, "group is required") return } + id, err := strconv.Atoi(c.Param("id")) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } + token, err := model.GetGroupTokenByID(group, id) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - lastRequestAt, _ := model.GetTokenLastRequestTime(id) - tokenResponse := &TokenResponse{ - Token: token, - AccessedAt: lastRequestAt, - } - middleware.SuccessResponse(c, tokenResponse) -} - -func validateToken(token AddTokenRequest) error { - if token.Name == "" { - return errors.New("token name cannot be empty") - } - if len(token.Name) > 30 { - return errors.New("token name is too long") - } - if token.Subnet != "" { - err := network.IsValidSubnets(token.Subnet) - if err != nil { - return fmt.Errorf("invalid subnet: %w", err) - } - } - return nil -} -type AddTokenRequest struct { - Name string `json:"name"` - Subnet string `json:"subnet"` - Models []string `json:"models"` - ExpiredAt int64 `json:"expiredAt"` - Quota float64 `json:"quota"` + middleware.SuccessResponse(c, buildTokenResponse(token)) } func AddToken(c *gin.Context) { group := c.Param("group") - token := AddTokenRequest{} - err := c.ShouldBindJSON(&token) - if err != nil { + var req AddTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = validateToken(token) - if err != nil { + + if err := validateToken(req); err != nil { middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error()) return } var expiredAt time.Time - if token.ExpiredAt == 0 { - expiredAt = time.Time{} - } else { - expiredAt = time.UnixMilli(token.ExpiredAt) + if req.ExpiredAt > 0 { + expiredAt = time.UnixMilli(req.ExpiredAt) } - cleanToken := &model.Token{ + token := &model.Token{ GroupID: group, - Name: model.EmptyNullString(token.Name), + Name: model.EmptyNullString(req.Name), Key: random.GenerateKey(), ExpiredAt: expiredAt, - Quota: token.Quota, - Models: token.Models, - Subnet: token.Subnet, + Quota: req.Quota, + Models: req.Models, + Subnet: req.Subnet, } - err = model.InsertToken(cleanToken, c.Query("auto_create_group") == "true") - if err != nil { + + if err := model.InsertToken(token, c.Query("auto_create_group") == "true"); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - middleware.SuccessResponse(c, &TokenResponse{ - Token: cleanToken, - }) + + middleware.SuccessResponse(c, &TokenResponse{Token: token}) } +// Delete handlers func DeleteToken(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = model.DeleteTokenByID(id) - if err != nil { + + if err := model.DeleteTokenByID(id); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } + middleware.SuccessResponse(c, nil) } func DeleteTokens(c *gin.Context) { - ids := []int{} - err := c.ShouldBindJSON(&ids) - if err != nil { + var ids []int + if err := c.ShouldBindJSON(&ids); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = model.DeleteTokensByIDs(ids) - if err != nil { + + if err := model.DeleteTokensByIDs(ids); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } + middleware.SuccessResponse(c, nil) } @@ -328,69 +297,73 @@ func DeleteGroupToken(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = model.DeleteGroupTokenByID(group, id) - if err != nil { + + if err := model.DeleteGroupTokenByID(group, id); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } + middleware.SuccessResponse(c, nil) } func DeleteGroupTokens(c *gin.Context) { group := c.Param("group") - ids := []int{} - err := c.ShouldBindJSON(&ids) - if err != nil { + var ids []int + if err := c.ShouldBindJSON(&ids); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = model.DeleteGroupTokensByIDs(group, ids) - if err != nil { + + if err := model.DeleteGroupTokensByIDs(group, ids); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } + middleware.SuccessResponse(c, nil) } +// Update handlers func UpdateToken(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - token := AddTokenRequest{} - err = c.ShouldBindJSON(&token) - if err != nil { + + var req AddTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = validateToken(token) - if err != nil { + + if err := validateToken(req); err != nil { middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error()) return } - cleanToken, err := model.GetTokenByID(id) + + token, err := model.GetTokenByID(id) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - expiredAt := time.Time{} - if token.ExpiredAt != 0 { - expiredAt = time.UnixMilli(token.ExpiredAt) + + var expiredAt time.Time + if req.ExpiredAt > 0 { + expiredAt = time.UnixMilli(req.ExpiredAt) } - cleanToken.Name = model.EmptyNullString(token.Name) - cleanToken.ExpiredAt = expiredAt - cleanToken.Quota = token.Quota - cleanToken.Models = token.Models - cleanToken.Subnet = token.Subnet - err = model.UpdateToken(cleanToken) - if err != nil { + + token.Name = model.EmptyNullString(req.Name) + token.ExpiredAt = expiredAt + token.Quota = req.Quota + token.Models = req.Models + token.Subnet = req.Subnet + + if err := model.UpdateToken(token); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - middleware.SuccessResponse(c, &TokenResponse{ - Token: cleanToken, - }) + + middleware.SuccessResponse(c, &TokenResponse{Token: token}) } func UpdateGroupToken(c *gin.Context) { @@ -400,43 +373,41 @@ func UpdateGroupToken(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - token := AddTokenRequest{} - err = c.ShouldBindJSON(&token) - if err != nil { + + var req AddTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = validateToken(token) - if err != nil { + + if err := validateToken(req); err != nil { middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error()) return } - cleanToken, err := model.GetGroupTokenByID(group, id) + + token, err := model.GetGroupTokenByID(group, id) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - expiredAt := time.Time{} - if token.ExpiredAt != 0 { - expiredAt = time.UnixMilli(token.ExpiredAt) + + var expiredAt time.Time + if req.ExpiredAt > 0 { + expiredAt = time.UnixMilli(req.ExpiredAt) } - cleanToken.Name = model.EmptyNullString(token.Name) - cleanToken.ExpiredAt = expiredAt - cleanToken.Quota = token.Quota - cleanToken.Models = token.Models - cleanToken.Subnet = token.Subnet - err = model.UpdateToken(cleanToken) - if err != nil { + + token.Name = model.EmptyNullString(req.Name) + token.ExpiredAt = expiredAt + token.Quota = req.Quota + token.Models = req.Models + token.Subnet = req.Subnet + + if err := model.UpdateToken(token); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - middleware.SuccessResponse(c, &TokenResponse{ - Token: cleanToken, - }) -} -type UpdateTokenStatusRequest struct { - Status int `json:"status"` + middleware.SuccessResponse(c, &TokenResponse{Token: token}) } func UpdateTokenStatus(c *gin.Context) { @@ -445,35 +416,32 @@ func UpdateTokenStatus(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - token := UpdateTokenStatusRequest{} - err = c.ShouldBindJSON(&token) - if err != nil { + + var req UpdateTokenStatusRequest + if err := c.ShouldBindJSON(&req); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - cleanToken, err := model.GetTokenByID(id) + + token, err := model.GetTokenByID(id) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - if token.Status == model.TokenStatusEnabled { - if err := validateTokenStatus(cleanToken); err != nil { + if req.Status == model.TokenStatusEnabled { + if err := validateTokenStatus(token); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } } - err = model.UpdateTokenStatus(id, token.Status) - if err != nil { + if err := model.UpdateTokenStatus(id, req.Status); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - middleware.SuccessResponse(c, nil) -} -type UpdateGroupTokenStatusRequest struct { - UpdateTokenStatusRequest + middleware.SuccessResponse(c, nil) } func UpdateGroupTokenStatus(c *gin.Context) { @@ -483,45 +451,32 @@ func UpdateGroupTokenStatus(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - token := UpdateTokenStatusRequest{} - err = c.ShouldBindJSON(&token) - if err != nil { + + var req UpdateTokenStatusRequest + if err := c.ShouldBindJSON(&req); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - cleanToken, err := model.GetGroupTokenByID(group, id) + + token, err := model.GetGroupTokenByID(group, id) if err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - if token.Status == model.TokenStatusEnabled { - if err := validateTokenStatus(cleanToken); err != nil { + if req.Status == model.TokenStatusEnabled { + if err := validateTokenStatus(token); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } } - err = model.UpdateGroupTokenStatus(group, id, token.Status) - if err != nil { + if err := model.UpdateGroupTokenStatus(group, id, req.Status); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - middleware.SuccessResponse(c, nil) -} - -func validateTokenStatus(token *model.Token) error { - if token.Status == model.TokenStatusExpired && !token.ExpiredAt.IsZero() && token.ExpiredAt.Before(time.Now()) { - return errors.New("token expired, please update token expired time or set to never expire") - } - if token.Status == model.TokenStatusExhausted && token.Quota > 0 && token.UsedAmount >= token.Quota { - return errors.New("token quota exhausted, please update token quota or set to unlimited quota") - } - return nil -} -type UpdateTokenNameRequest struct { - Name string `json:"name"` + middleware.SuccessResponse(c, nil) } func UpdateTokenName(c *gin.Context) { @@ -530,17 +485,18 @@ func UpdateTokenName(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - name := UpdateTokenNameRequest{} - err = c.ShouldBindJSON(&name) - if err != nil { + + var req UpdateTokenNameRequest + if err := c.ShouldBindJSON(&req); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = model.UpdateTokenName(id, name.Name) - if err != nil { + + if err := model.UpdateTokenName(id, req.Name); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } + middleware.SuccessResponse(c, nil) } @@ -551,16 +507,17 @@ func UpdateGroupTokenName(c *gin.Context) { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - name := UpdateTokenNameRequest{} - err = c.ShouldBindJSON(&name) - if err != nil { + + var req UpdateTokenNameRequest + if err := c.ShouldBindJSON(&req); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } - err = model.UpdateGroupTokenName(group, id, name.Name) - if err != nil { + + if err := model.UpdateGroupTokenName(group, id, req.Name); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } + middleware.SuccessResponse(c, nil) } diff --git a/service/aiproxy/controller/utils.go b/service/aiproxy/controller/utils.go new file mode 100644 index 00000000000..f66c5f31376 --- /dev/null +++ b/service/aiproxy/controller/utils.go @@ -0,0 +1,23 @@ +package controller + +import ( + "strconv" + + "github.com/gin-gonic/gin" +) + +func parsePageParams(c *gin.Context) (page, perPage int) { + page, _ = strconv.Atoi(c.Query("p")) + page-- + if page < 0 { + page = 0 + } + + perPage, _ = strconv.Atoi(c.Query("per_page")) + if perPage <= 0 { + perPage = 10 + } else if perPage > 100 { + perPage = 100 + } + return +} diff --git a/service/aiproxy/middleware/auth.go b/service/aiproxy/middleware/auth.go index 2390d11e881..4d00b34133b 100644 --- a/service/aiproxy/middleware/auth.go +++ b/service/aiproxy/middleware/auth.go @@ -58,8 +58,6 @@ func TokenAuth(c *gin.Context) { strings.TrimPrefix(key, "Bearer "), "sk-", ) - parts := strings.Split(key, "-") - key = parts[0] var token *model.TokenCache var useInternalToken bool diff --git a/service/aiproxy/model/log.go b/service/aiproxy/model/log.go index 18cdcba0a60..d220deaf2d6 100644 --- a/service/aiproxy/model/log.go +++ b/service/aiproxy/model/log.go @@ -23,13 +23,13 @@ const ( ) type RequestDetail struct { - CreatedAt time.Time `gorm:"autoCreateTime" json:"-"` + CreatedAt time.Time `gorm:"autoCreateTime;index" json:"-"` RequestBody string `gorm:"type:text" json:"request_body,omitempty"` ResponseBody string `gorm:"type:text" json:"response_body,omitempty"` RequestBodyTruncated bool `json:"request_body_truncated,omitempty"` ResponseBodyTruncated bool `json:"response_body_truncated,omitempty"` ID int `json:"id"` - LogID int `json:"log_id"` + LogID int `gorm:"index" json:"log_id"` } func (d *RequestDetail) BeforeSave(_ *gorm.DB) (err error) { @@ -45,27 +45,89 @@ func (d *RequestDetail) BeforeSave(_ *gorm.DB) (err error) { } type Log struct { - RequestDetail *RequestDetail `gorm:"foreignKey:LogID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"request_detail,omitempty"` - RequestAt time.Time `gorm:"index;index:idx_request_at_group_id,priority:2;index:idx_group_reqat_token,priority:2" json:"request_at"` - CreatedAt time.Time `gorm:"index" json:"created_at"` - TokenName string `gorm:"index;index:idx_group_token,priority:2;index:idx_group_reqat_token,priority:3" json:"token_name,omitempty"` - Endpoint string `gorm:"index" json:"endpoint"` - Content string `gorm:"type:text" json:"content,omitempty"` - GroupID string `gorm:"index;index:idx_group_token,priority:1;index:idx_request_at_group_id,priority:1;index:idx_group_reqat_token,priority:1" json:"group,omitempty"` - Model string `gorm:"index" json:"model"` - RequestID string `gorm:"index" json:"request_id"` - Price float64 `json:"price"` - ID int `gorm:"primaryKey" json:"id"` - CompletionPrice float64 `json:"completion_price"` - TokenID int `gorm:"index" json:"token_id,omitempty"` - UsedAmount float64 `gorm:"index" json:"used_amount"` - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - ChannelID int `gorm:"index" json:"channel"` - Code int `gorm:"index" json:"code"` - Mode int `json:"mode"` - IP string `json:"ip"` + RequestDetail *RequestDetail `gorm:"foreignKey:LogID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"request_detail,omitempty"` + RequestAt time.Time `gorm:"index" json:"request_at"` + TimestampTruncByDay int64 `json:"timestamp_trunc_by_day"` + TimestampTruncByHour int64 `json:"timestamp_trunc_by_hour"` + CreatedAt time.Time `json:"created_at"` + TokenName string `json:"token_name,omitempty"` + Endpoint string `json:"endpoint"` + Content string `gorm:"type:text" json:"content,omitempty"` + GroupID string `gorm:"index" json:"group,omitempty"` + Model string `gorm:"index" json:"model"` + RequestID string `gorm:"index" json:"request_id"` + Price float64 `json:"price"` + ID int `gorm:"primaryKey" json:"id"` + CompletionPrice float64 `json:"completion_price"` + TokenID int `gorm:"index" json:"token_id,omitempty"` + UsedAmount float64 `json:"used_amount"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + ChannelID int `gorm:"index" json:"channel"` + Code int `gorm:"index" json:"code"` + Mode int `json:"mode"` + IP string `gorm:"index" json:"ip"` +} + +func CreateLogIndexes(db *gorm.DB) error { + var indexes []string + if common.UsingSQLite { + // not support INCLUDE + indexes = []string{ + // used by search group logs + "CREATE INDEX IF NOT EXISTS idx_group_token_reqat ON logs (group_id, token_name, request_at)", + // used by search group logs + "CREATE INDEX IF NOT EXISTS idx_group_model_reqat ON logs (group_id, model, request_at)", + // used by group used tokens + "CREATE INDEX IF NOT EXISTS idx_group_reqat_token ON logs (group_id, request_at, token_name)", + // used by group used models + "CREATE INDEX IF NOT EXISTS idx_group_reqat_model ON logs (group_id, request_at, model)", + // used by search group logs + "CREATE INDEX IF NOT EXISTS idx_group_token_model_reqat ON logs (group_id, token_name, model, request_at)", + + // day indexes, used by dashboard + "CREATE INDEX IF NOT EXISTS idx_group_model_reqat_truncday ON logs (group_id, model, request_at, timestamp_trunc_by_day)", + "CREATE INDEX IF NOT EXISTS idx_group_token_reqat_truncday ON logs (group_id, token_name, request_at, timestamp_trunc_by_day)", + "CREATE INDEX IF NOT EXISTS idx_group_model_token_reqat_truncday ON logs (group_id, model, token_name, request_at, timestamp_trunc_by_day)", + // hour indexes, used by dashboard + "CREATE INDEX IF NOT EXISTS idx_group_model_reqat_trunchour ON logs (group_id, model, request_at, timestamp_trunc_by_hour)", + "CREATE INDEX IF NOT EXISTS idx_group_token_reqat_trunchour ON logs (group_id, token_name, request_at, timestamp_trunc_by_hour)", + "CREATE INDEX IF NOT EXISTS idx_group_model_token_reqat_trunchour ON logs (group_id, model, token_name, request_at, timestamp_trunc_by_hour)", + } + } else { + indexes = []string{ + // used by search group logs + "CREATE INDEX IF NOT EXISTS idx_group_token_reqat ON logs (group_id, token_name, request_at) INCLUDE (code)", + // used by search group logs + "CREATE INDEX IF NOT EXISTS idx_group_model_reqat ON logs (group_id, model, request_at) INCLUDE (code)", + // used by group used tokens + "CREATE INDEX IF NOT EXISTS idx_group_reqat_token ON logs (group_id, request_at, token_name)", + // used by group used models + "CREATE INDEX IF NOT EXISTS idx_group_reqat_model ON logs (group_id, request_at, model)", + // used by search group logs + "CREATE INDEX IF NOT EXISTS idx_group_token_model_reqat ON logs (group_id, token_name, model, request_at) INCLUDE (code)", + + // day indexes, used by dashboard + "CREATE INDEX IF NOT EXISTS idx_group_model_reqat_truncday ON logs (group_id, request_at, timestamp_trunc_by_day) INCLUDE (code, used_amount, total_tokens)", + "CREATE INDEX IF NOT EXISTS idx_group_model_reqat_truncday ON logs (group_id, model, request_at, timestamp_trunc_by_day) INCLUDE (code, used_amount, total_tokens)", + "CREATE INDEX IF NOT EXISTS idx_group_token_reqat_truncday ON logs (group_id, token_name, request_at, timestamp_trunc_by_day) INCLUDE (code, used_amount, total_tokens)", + "CREATE INDEX IF NOT EXISTS idx_group_model_token_reqat_truncday ON logs (group_id, model, token_name, request_at, timestamp_trunc_by_day) INCLUDE (code, used_amount, total_tokens)", + // hour indexes, used by dashboard + "CREATE INDEX IF NOT EXISTS idx_group_model_reqat_trunchour ON logs (group_id, request_at, timestamp_trunc_by_hour) INCLUDE (code, used_amount, total_tokens)", + "CREATE INDEX IF NOT EXISTS idx_group_model_reqat_trunchour ON logs (group_id, model, request_at, timestamp_trunc_by_hour) INCLUDE (code, used_amount, total_tokens)", + "CREATE INDEX IF NOT EXISTS idx_group_token_reqat_trunchour ON logs (group_id, token_name, request_at, timestamp_trunc_by_hour) INCLUDE (code, used_amount, total_tokens)", + "CREATE INDEX IF NOT EXISTS idx_group_model_token_reqat_trunchour ON logs (group_id, model, token_name, request_at, timestamp_trunc_by_hour) INCLUDE (code, used_amount, total_tokens)", + } + } + + for _, index := range indexes { + if err := db.Exec(index).Error; err != nil { + return err + } + } + + return nil } const ( @@ -76,6 +138,12 @@ func (l *Log) BeforeSave(_ *gorm.DB) (err error) { if len(l.Content) > contentMaxSize { l.Content = common.TruncateByRune(l.Content, contentMaxSize) + "..." } + if l.TimestampTruncByDay == 0 { + l.TimestampTruncByDay = l.RequestAt.Truncate(24 * time.Hour).Unix() + } + if l.TimestampTruncByHour == 0 { + l.TimestampTruncByHour = l.RequestAt.Truncate(time.Hour).Unix() + } return } @@ -188,11 +256,10 @@ func RecordConsumeLog( return LogDB.Create(log).Error } -//nolint:goconst func getLogOrder(order string) string { prefix, suffix, _ := strings.Cut(order, "-") switch prefix { - case "used_amount", "token_id", "token_name", "group", "request_id", "request_at", "id", "created_at": + case "request_at", "id", "created_at": switch suffix { case "asc": return prefix + " asc" @@ -241,10 +308,11 @@ func buildGetLogsQuery( if group != "" { tx = tx.Where("group_id = ?", group) } - if !startTimestamp.IsZero() { + if !startTimestamp.IsZero() && !endTimestamp.IsZero() { + tx = tx.Where("request_at BETWEEN ? AND ?", startTimestamp, endTimestamp) + } else if !startTimestamp.IsZero() { tx = tx.Where("request_at >= ?", startTimestamp) - } - if !endTimestamp.IsZero() { + } else if !endTimestamp.IsZero() { tx = tx.Where("request_at <= ?", endTimestamp) } if tokenName != "" { @@ -396,7 +464,7 @@ func GetLogs( g.Go(func() error { var err error - models, err = getLogDistinctValues[string]("model", group, startTimestamp, endTimestamp) + models, err = getLogGroupByValues[string]("model", group, startTimestamp, endTimestamp) return err }) @@ -452,13 +520,13 @@ func GetGroupLogs( g.Go(func() error { var err error - tokenNames, err = getLogDistinctValues[string]("token_name", group, startTimestamp, endTimestamp) + tokenNames, err = getLogGroupByValues[string]("token_name", group, startTimestamp, endTimestamp) return err }) g.Go(func() error { var err error - models, err = getLogDistinctValues[string]("model", group, startTimestamp, endTimestamp) + models, err = getLogGroupByValues[string]("model", group, startTimestamp, endTimestamp) return err }) @@ -496,28 +564,32 @@ func buildSearchLogsQuery( tx = tx.Where("group_id = ?", group) } - // Handle exact match conditions for non-zero values - if !startTimestamp.IsZero() { - tx = tx.Where("request_at >= ?", startTimestamp) - } - if !endTimestamp.IsZero() { - tx = tx.Where("request_at <= ?", endTimestamp) - } if tokenName != "" { tx = tx.Where("token_name = ?", tokenName) } + if modelName != "" { tx = tx.Where("model = ?", modelName) } + + if !startTimestamp.IsZero() && !endTimestamp.IsZero() { + tx = tx.Where("request_at BETWEEN ? AND ?", startTimestamp, endTimestamp) + } else if !startTimestamp.IsZero() { + tx = tx.Where("request_at >= ?", startTimestamp) + } else if !endTimestamp.IsZero() { + tx = tx.Where("request_at <= ?", endTimestamp) + } + + if requestID != "" { + tx = tx.Where("request_id = ?", requestID) + } + if mode != 0 { tx = tx.Where("mode = ?", mode) } if endpoint != "" { tx = tx.Where("endpoint = ?", endpoint) } - if requestID != "" { - tx = tx.Where("request_id = ?", requestID) - } if tokenID != 0 { tx = tx.Where("token_id = ?", tokenID) } @@ -543,21 +615,6 @@ func buildSearchLogsQuery( conditions = append(conditions, "group_id = ?") values = append(values, keyword) } - - if num := String2Int(keyword); num != 0 { - if channelID == 0 { - conditions = append(conditions, "channel_id = ?") - values = append(values, num) - } - if mode != 0 { - conditions = append(conditions, "mode = ?") - values = append(values, num) - } - } - if requestID == "" { - conditions = append(conditions, "request_id = ?") - values = append(values, keyword) - } if tokenName == "" { conditions = append(conditions, "token_name = ?") values = append(values, keyword) @@ -566,12 +623,27 @@ func buildSearchLogsQuery( conditions = append(conditions, "model = ?") values = append(values, keyword) } - - if ip != "" { - conditions = append(conditions, "ip = ?") - values = append(values, ip) + if requestID == "" { + conditions = append(conditions, "request_id = ?") + values = append(values, keyword) } + // if num := String2Int(keyword); num != 0 { + // if channelID == 0 { + // conditions = append(conditions, "channel_id = ?") + // values = append(values, num) + // } + // if mode != 0 { + // conditions = append(conditions, "mode = ?") + // values = append(values, num) + // } + // } + + // if ip != "" { + // conditions = append(conditions, "ip = ?") + // values = append(values, ip) + // } + // if endpoint == "" { // if common.UsingPostgreSQL { // conditions = append(conditions, "endpoint ILIKE ?") @@ -581,12 +653,13 @@ func buildSearchLogsQuery( // values = append(values, "%"+keyword+"%") // } - if common.UsingPostgreSQL { - conditions = append(conditions, "content ILIKE ?") - } else { - conditions = append(conditions, "content LIKE ?") - } - values = append(values, "%"+keyword+"%") + // slow query + // if common.UsingPostgreSQL { + // conditions = append(conditions, "content ILIKE ?") + // } else { + // conditions = append(conditions, "content LIKE ?") + // } + // values = append(values, "%"+keyword+"%") if len(conditions) > 0 { tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...) @@ -717,7 +790,7 @@ func SearchLogs( g.Go(func() error { var err error - models, err = getLogDistinctValues[string]("model", group, startTimestamp, endTimestamp) + models, err = getLogGroupByValues[string]("model", group, startTimestamp, endTimestamp) return err }) @@ -774,13 +847,13 @@ func SearchGroupLogs( g.Go(func() error { var err error - tokenNames, err = getLogDistinctValues[string]("token_name", group, startTimestamp, endTimestamp) + tokenNames, err = getLogGroupByValues[string]("token_name", group, startTimestamp, endTimestamp) return err }) g.Go(func() error { var err error - models, err = getLogDistinctValues[string]("model", group, startTimestamp, endTimestamp) + models, err = getLogGroupByValues[string]("model", group, startTimestamp, endTimestamp) return err }) @@ -835,20 +908,25 @@ type GroupDashboardResponse struct { TokenNames []string `json:"token_names"` } -func getTimeSpanFormat(timeSpan time.Duration) string { - switch { - case common.UsingMySQL: - return fmt.Sprintf("UNIX_TIMESTAMP(DATE_FORMAT(request_at, '%%Y-%%m-%%d %%H:%%i:00')) DIV %d * %d", int64(timeSpan.Seconds()), int64(timeSpan.Seconds())) - case common.UsingPostgreSQL: - return fmt.Sprintf("FLOOR(EXTRACT(EPOCH FROM date_trunc('minute', request_at)) / %d) * %d", int64(timeSpan.Seconds()), int64(timeSpan.Seconds())) - case common.UsingSQLite: - return fmt.Sprintf("CAST(STRFTIME('%%s', STRFTIME('%%Y-%%m-%%d %%H:%%M:00', request_at)) AS INTEGER) / %d * %d", int64(timeSpan.Seconds()), int64(timeSpan.Seconds())) +type TimeSpanType string + +const ( + TimeSpanDay TimeSpanType = "day" + TimeSpanHour TimeSpanType = "hour" +) + +func getTimeSpanFormat(t TimeSpanType) string { + switch t { + case TimeSpanDay: + return "timestamp_trunc_by_day" + case TimeSpanHour: + return "timestamp_trunc_by_hour" default: return "" } } -func getChartData(group string, start, end time.Time, tokenName, modelName string, timeSpan time.Duration) ([]*ChartData, error) { +func getChartData(group string, start, end time.Time, tokenName, modelName string, timeSpan TimeSpanType) ([]*ChartData, error) { var chartData []*ChartData timeSpanFormat := getTimeSpanFormat(timeSpan) @@ -864,10 +942,12 @@ func getChartData(group string, start, end time.Time, tokenName, modelName strin if group != "" { query = query.Where("group_id = ?", group) } - if !start.IsZero() { + + if !start.IsZero() && !end.IsZero() { + query = query.Where("request_at BETWEEN ? AND ?", start, end) + } else if !start.IsZero() { query = query.Where("request_at >= ?", start) - } - if !end.IsZero() { + } else if !end.IsZero() { query = query.Where("request_at <= ?", end) } @@ -883,24 +963,55 @@ func getChartData(group string, start, end time.Time, tokenName, modelName strin return chartData, err } +//nolint:unused func getLogDistinctValues[T cmp.Ordered](field string, group string, start, end time.Time) ([]T, error) { var values []T query := LogDB. - Model(&Log{}). - Distinct(field) + Model(&Log{}) if group != "" { query = query.Where("group_id = ?", group) } - if !start.IsZero() { + if !start.IsZero() && !end.IsZero() { + query = query.Where("request_at BETWEEN ? AND ?", start, end) + } else if !start.IsZero() { query = query.Where("request_at >= ?", start) + } else if !end.IsZero() { + query = query.Where("request_at <= ?", end) } - if !end.IsZero() { + + err := query. + Distinct(field). + Pluck(field, &values).Error + if err != nil { + return nil, err + } + slices.Sort(values) + return values, nil +} + +func getLogGroupByValues[T cmp.Ordered](field string, group string, start, end time.Time) ([]T, error) { + var values []T + query := LogDB. + Model(&Log{}) + + if group != "" { + query = query.Where("group_id = ?", group) + } + + if !start.IsZero() && !end.IsZero() { + query = query.Where("request_at BETWEEN ? AND ?", start, end) + } else if !start.IsZero() { + query = query.Where("request_at >= ?", start) + } else if !end.IsZero() { query = query.Where("request_at <= ?", end) } - err := query.Pluck(field, &values).Error + err := query. + Select(field). + Group(field). + Pluck(field, &values).Error if err != nil { return nil, err } @@ -933,8 +1044,7 @@ func sumUsedAmount(chartData []*ChartData) float64 { } func getRPM(group string, end time.Time, tokenName, modelName string) (int64, error) { - query := LogDB.Model(&Log{}). - Where("request_at >= ? AND request_at <= ?", end.Add(-time.Minute), end) + query := LogDB.Model(&Log{}) if group != "" { query = query.Where("group_id = ?", group) @@ -947,7 +1057,9 @@ func getRPM(group string, end time.Time, tokenName, modelName string) (int64, er } var count int64 - err := query.Count(&count).Error + err := query. + Where("request_at BETWEEN ? AND ?", end.Add(-time.Minute), end). + Count(&count).Error return count, err } @@ -971,7 +1083,7 @@ func getTPM(group string, end time.Time, tokenName, modelName string) (int64, er return tpm, err } -func GetDashboardData(start, end time.Time, modelName string, timeSpan time.Duration) (*DashboardResponse, error) { +func GetDashboardData(start, end time.Time, modelName string, timeSpan TimeSpanType) (*DashboardResponse, error) { if end.IsZero() { end = time.Now() } else if end.Before(start) { @@ -995,7 +1107,7 @@ func GetDashboardData(start, end time.Time, modelName string, timeSpan time.Dura g.Go(func() error { var err error - models, err = getLogDistinctValues[string]("model", "", start, end) + models, err = getLogGroupByValues[string]("model", "", start, end) return err }) @@ -1030,7 +1142,7 @@ func GetDashboardData(start, end time.Time, modelName string, timeSpan time.Dura }, nil } -func GetGroupDashboardData(group string, start, end time.Time, tokenName string, modelName string, timeSpan time.Duration) (*GroupDashboardResponse, error) { +func GetGroupDashboardData(group string, start, end time.Time, tokenName string, modelName string, timeSpan TimeSpanType) (*GroupDashboardResponse, error) { if group == "" { return nil, errors.New("group is required") } @@ -1059,13 +1171,13 @@ func GetGroupDashboardData(group string, start, end time.Time, tokenName string, g.Go(func() error { var err error - tokenNames, err = getLogDistinctValues[string]("token_name", group, start, end) + tokenNames, err = getLogGroupByValues[string]("token_name", group, start, end) return err }) g.Go(func() error { var err error - models, err = getLogDistinctValues[string]("model", group, start, end) + models, err = getLogGroupByValues[string]("model", group, start, end) return err }) diff --git a/service/aiproxy/model/main.go b/service/aiproxy/model/main.go index 6d5cbb109fd..717936535ae 100644 --- a/service/aiproxy/model/main.go +++ b/service/aiproxy/model/main.go @@ -32,13 +32,19 @@ func chooseDB(envName string) (*gorm.DB, error) { switch { case strings.HasPrefix(dsn, "postgres"): // Use PostgreSQL - return openPostgreSQL(dsn) - case dsn != "": + log.Info("using PostgreSQL as database") + common.UsingPostgreSQL = true + return OpenPostgreSQL(dsn) + case strings.HasPrefix(dsn, "mysql"): // Use MySQL - return openMySQL(dsn) + log.Info("using MySQL as database") + common.UsingMySQL = true + return OpenMySQL(dsn) default: // Use SQLite - return openSQLite() + log.Info("SQL_DSN not set, using SQLite as database: ", common.SQLitePath) + common.UsingSQLite = true + return OpenSQLite(common.SQLitePath) } } @@ -61,9 +67,7 @@ func newDBLogger() gormLogger.Interface { ) } -func openPostgreSQL(dsn string) (*gorm.DB, error) { - log.Info("using PostgreSQL as database") - common.UsingPostgreSQL = true +func OpenPostgreSQL(dsn string) (*gorm.DB, error) { return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, PreferSimpleProtocol: true, // disables implicit prepared statement usage @@ -76,10 +80,8 @@ func openPostgreSQL(dsn string) (*gorm.DB, error) { }) } -func openMySQL(dsn string) (*gorm.DB, error) { - log.Info("using MySQL as database") - common.UsingMySQL = true - return gorm.Open(mysql.Open(dsn), &gorm.Config{ +func OpenMySQL(dsn string) (*gorm.DB, error) { + return gorm.Open(mysql.Open(strings.TrimPrefix(dsn, "mysql://")), &gorm.Config{ PrepareStmt: true, // precompile SQL TranslateError: true, Logger: newDBLogger(), @@ -88,17 +90,13 @@ func openMySQL(dsn string) (*gorm.DB, error) { }) } -func openSQLite() (*gorm.DB, error) { - log.Info("SQL_DSN not set, using SQLite as database: ", common.SQLitePath) - common.UsingSQLite = true - - baseDir := filepath.Dir(common.SQLitePath) +func OpenSQLite(sqlitePath string) (*gorm.DB, error) { + baseDir := filepath.Dir(sqlitePath) if err := os.MkdirAll(baseDir, 0o755); err != nil { - log.Fatal("failed to create base directory: " + err.Error()) - return nil, err + return nil, fmt.Errorf("failed to create base directory: %w", err) } - dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout) + dsn := fmt.Sprintf("%s?_busy_timeout=%d", sqlitePath, common.SQLiteBusyTimeout) return gorm.Open(sqlite.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL TranslateError: true, @@ -184,11 +182,16 @@ func InitLogDB() { } func migrateLOGDB() error { - return LogDB.AutoMigrate( + err := LogDB.AutoMigrate( &Log{}, &RequestDetail{}, &ConsumeError{}, ) + if err != nil { + return err + } + + return CreateLogIndexes(LogDB) } func setDBConns(db *gorm.DB) { diff --git a/service/aiproxy/monitor/model.go b/service/aiproxy/monitor/model.go index fca8217efdb..19018e73cad 100644 --- a/service/aiproxy/monitor/model.go +++ b/service/aiproxy/monitor/model.go @@ -15,34 +15,76 @@ import ( // Redis key prefixes and patterns const ( - modelKeyPrefix = "model:" - bannedKeySuffix = ":banned" - statsKeySuffix = ":stats" - channelKeyPart = ":channel:" + modelKeyPrefix = "model:" + bannedKeySuffix = ":banned" + statsKeySuffix = ":stats" + modelTotalStatsSuffix = ":total_stats" + channelKeyPart = ":channel:" ) // Redis scripts var ( addRequestScript = redis.NewScript(addRequestLuaScript) getChannelModelErrorRateScript = redis.NewScript(getChannelModelErrorRateLuaScript) + getModelErrorRateScript = redis.NewScript(getModelErrorRateLuaScript) getBannedChannelsScript = redis.NewScript(getBannedChannelsLuaScript) clearChannelModelErrorsScript = redis.NewScript(clearChannelModelErrorsLuaScript) clearChannelAllModelErrorsScript = redis.NewScript(clearChannelAllModelErrorsLuaScript) clearAllModelErrorsScript = redis.NewScript(clearAllModelErrorsLuaScript) ) -// Helper functions -func isFeatureEnabled() bool { - return common.RedisEnabled && config.GetEnableModelErrorAutoBan() -} - func buildStatsKey(model string, channelID interface{}) string { return fmt.Sprintf("%s%s%s%v%s", modelKeyPrefix, model, channelKeyPart, channelID, statsKeySuffix) } +// GetModelErrorRate gets error rate for a specific model across all channels +func GetModelsErrorRate(ctx context.Context) (map[string]float64, error) { + if !common.RedisEnabled { + return map[string]float64{}, nil + } + + result := make(map[string]float64) + pattern := modelKeyPrefix + "*" + modelTotalStatsSuffix + + iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator() + for iter.Next(ctx) { + key := iter.Val() + parts := strings.Split(key, ":") + if len(parts) != 3 || parts[2] != "total_stats" { + continue + } + model := parts[1] + + rate, err := getModelErrorRateScript.Run( + ctx, + common.RDB, + []string{key}, + time.Now().UnixMilli(), + ).Float64() + if err != nil { + return nil, err + } + + result[model] = rate + } + + if err := iter.Err(); err != nil { + return nil, err + } + + return result, nil +} + +func canAutoBan() int { + if common.RedisEnabled && config.GetEnableModelErrorAutoBan() { + return 1 + } + return 0 +} + // AddRequest adds a request record and checks if channel should be banned func AddRequest(ctx context.Context, model string, channelID int64, isError bool) error { - if !isFeatureEnabled() { + if !common.RedisEnabled { return nil } @@ -61,6 +103,7 @@ func AddRequest(ctx context.Context, model string, channelID int64, isError bool now, config.GetModelErrorAutoBanRate(), time.Second.Milliseconds()*15, + canAutoBan(), ).Int64() if err != nil { return err @@ -75,8 +118,8 @@ func AddRequest(ctx context.Context, model string, channelID int64, isError bool // GetChannelModelErrorRates gets error rates for a specific channel func GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string]float64, error) { - if !isFeatureEnabled() { - return nil, nil + if !common.RedisEnabled { + return map[string]float64{}, nil } result := make(map[string]float64) @@ -114,8 +157,8 @@ func GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string // GetBannedChannels gets banned channels for a specific model func GetBannedChannels(ctx context.Context, model string) ([]int64, error) { - if !isFeatureEnabled() { - return nil, nil + if !common.RedisEnabled || !config.GetEnableModelErrorAutoBan() { + return []int64{}, nil } result, err := getBannedChannelsScript.Run(ctx, common.RDB, []string{model}).Int64Slice() if err != nil { @@ -126,7 +169,7 @@ func GetBannedChannels(ctx context.Context, model string) ([]int64, error) { // ClearChannelModelErrors clears errors for a specific channel and model func ClearChannelModelErrors(ctx context.Context, model string, channelID int) error { - if !isFeatureEnabled() { + if !common.RedisEnabled { return nil } return clearChannelModelErrorsScript.Run( @@ -139,7 +182,7 @@ func ClearChannelModelErrors(ctx context.Context, model string, channelID int) e // ClearChannelAllModelErrors clears all errors for a specific channel func ClearChannelAllModelErrors(ctx context.Context, channelID int) error { - if !isFeatureEnabled() { + if !common.RedisEnabled { return nil } return clearChannelAllModelErrorsScript.Run( @@ -152,7 +195,7 @@ func ClearChannelAllModelErrors(ctx context.Context, channelID int) error { // ClearAllModelErrors clears all error records func ClearAllModelErrors(ctx context.Context) error { - if !isFeatureEnabled() { + if !common.RedisEnabled { return nil } return clearAllModelErrorsScript.Run(ctx, common.RDB, []string{}).Err() @@ -160,8 +203,8 @@ func ClearAllModelErrors(ctx context.Context) error { // GetAllBannedChannels gets all banned channels for all models func GetAllBannedChannels(ctx context.Context) (map[string][]int64, error) { - if !isFeatureEnabled() { - return nil, nil + if !common.RedisEnabled || !config.GetEnableModelErrorAutoBan() { + return map[string][]int64{}, nil } result := make(map[string][]int64) @@ -191,8 +234,8 @@ func GetAllBannedChannels(ctx context.Context) (map[string][]int64, error) { // GetAllChannelModelErrorRates gets error rates for all channels and models func GetAllChannelModelErrorRates(ctx context.Context) (map[int64]map[string]float64, error) { - if !isFeatureEnabled() { - return nil, nil + if !common.RedisEnabled { + return map[int64]map[string]float64{}, nil } result := make(map[int64]map[string]float64) @@ -245,9 +288,11 @@ local is_error = tonumber(ARGV[2]) local now_ts = tonumber(ARGV[3]) local max_error_rate = tonumber(ARGV[4]) local statsExpiry = tonumber(ARGV[5]) +local can_auto_ban = tonumber(ARGV[6]) local banned_key = "model:" .. model .. ":banned" local stats_key = "model:" .. model .. ":channel:" .. channel_id .. ":stats" +local model_stats_key = "model:" .. model .. ":total_stats" local maxSliceCount = 6 local current_slice = math.floor(now_ts / 1000) @@ -261,7 +306,7 @@ local function parse_req_err(value) return tonumber(r) or 0, tonumber(e) or 0 end -local function update_current_slice() +local function update_channel_stats() local req, err = parse_req_err(redis.call("HGET", stats_key, current_slice)) req = req + 1 err = err + (is_error == 1 and 1 or 0) @@ -270,7 +315,23 @@ local function update_current_slice() return req, err end -local function calculate_error_rate() +local function update_model_stats() + local req, err = parse_req_err(redis.call("HGET", model_stats_key, current_slice)) + req = req + 1 + err = err + (is_error == 1 and 1 or 0) + redis.call("HSET", model_stats_key, current_slice, req .. ":" .. err) + redis.call("PEXPIRE", model_stats_key, statsExpiry) + return req, err +end + +update_channel_stats() +update_model_stats() + +if is_error == 0 or can_auto_ban == 0 then + return 0 +end + +local function check_channel_error() local total_req, total_err = 0, 0 local min_valid_slice = current_slice - maxSliceCount @@ -292,24 +353,47 @@ local function calculate_error_rate() redis.call("HDEL", stats_key, unpack(to_delete)) end - return total_req, total_err + if total_req >= 10 and (total_err / total_req) >= max_error_rate then + redis.call("SADD", banned_key, channel_id) + redis.call("DEL", stats_key) + return true + end + return false end -update_current_slice() +if check_channel_error() then + return 1 +end +return 0 +` -if is_error == 0 then - return 0 + getModelErrorRateLuaScript = ` +local model_stats_key = KEYS[1] +local now_ts = tonumber(ARGV[1]) +local maxSliceCount = 6 +local current_slice = math.floor(now_ts / 1000) +local min_valid_slice = current_slice - maxSliceCount + +local function parse_req_err(value) + if not value then return 0, 0 end + local r, e = value:match("^(%d+):(%d+)$") + return tonumber(r) or 0, tonumber(e) or 0 end -local total_req, total_err = calculate_error_rate() +local total_req, total_err = 0, 0 +local all_slices = redis.call("HGETALL", model_stats_key) -if total_req >= 10 and (total_err / total_req) >= max_error_rate then - redis.call("SADD", banned_key, channel_id) - redis.call("DEL", stats_key) - return 1 +for i = 1, #all_slices, 2 do + local slice = tonumber(all_slices[i]) + if slice >= min_valid_slice then + local req, err = parse_req_err(all_slices[i+1]) + total_req = total_req + req + total_err = total_err + err + end end -return 0 +if total_req == 0 then return 0 end +return total_err / total_req ` getChannelModelErrorRateLuaScript = ` diff --git a/service/aiproxy/relay/adaptor/ali/adaptor.go b/service/aiproxy/relay/adaptor/ali/adaptor.go index ab48abb942e..81191c7722e 100644 --- a/service/aiproxy/relay/adaptor/ali/adaptor.go +++ b/service/aiproxy/relay/adaptor/ali/adaptor.go @@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { u = baseURL } switch meta.Mode { - case relaymode.Embeddings: - return u + "/api/v1/services/embeddings/text-embedding/text-embedding", nil case relaymode.ImagesGenerations: return u + "/api/v1/services/aigc/text2image/image-synthesis", nil case relaymode.ChatCompletions: return u + "/compatible-mode/v1/chat/completions", nil case relaymode.Completions: return u + "/compatible-mode/v1/completions", nil + case relaymode.Embeddings: + return u + "/compatible-mode/v1/embeddings", nil case relaymode.AudioSpeech, relaymode.AudioTranscription: return u + "/api-ws/v1/inference", nil case relaymode.Rerank: @@ -62,9 +62,7 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, ht return ConvertImageRequest(meta, req) case relaymode.Rerank: return ConvertRerankRequest(meta, req) - case relaymode.Embeddings: - return ConvertEmbeddingsRequest(meta, req) - case relaymode.ChatCompletions, relaymode.Completions: + case relaymode.ChatCompletions, relaymode.Completions, relaymode.Embeddings: return openai.ConvertRequest(meta, req) case relaymode.AudioSpeech: return ConvertTTSRequest(meta, req) @@ -105,11 +103,9 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons return &relaymodel.Usage{}, nil } switch meta.Mode { - case relaymode.Embeddings: - usage, err = EmbeddingsHandler(meta, c, resp) case relaymode.ImagesGenerations: usage, err = ImageHandler(meta, c, resp) - case relaymode.ChatCompletions, relaymode.Completions: + case relaymode.ChatCompletions, relaymode.Completions, relaymode.Embeddings: usage, err = openai.DoResponse(meta, c, resp) case relaymode.Rerank: usage, err = RerankHandler(meta, c, resp) diff --git a/service/aiproxy/relay/adaptor/ali/embeddings.go b/service/aiproxy/relay/adaptor/ali/embeddings.go index 877f6b66f50..b57577e454b 100644 --- a/service/aiproxy/relay/adaptor/ali/embeddings.go +++ b/service/aiproxy/relay/adaptor/ali/embeddings.go @@ -15,6 +15,9 @@ import ( relaymodel "github.com/labring/sealos/service/aiproxy/relay/model" ) +// Deprecated: Use openai.ConvertRequest instead +// /api/v1/services/embeddings/text-embedding/text-embedding + func ConvertEmbeddingsRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) { var reqMap map[string]any err := common.UnmarshalBodyReusable(req, &reqMap) diff --git a/service/aiproxy/router/api.go b/service/aiproxy/router/api.go index 7006a085c0f..22df97ed2ad 100644 --- a/service/aiproxy/router/api.go +++ b/service/aiproxy/router/api.go @@ -77,6 +77,11 @@ func SetAPIRouter(router *gin.Engine) { channelsRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) channelsRoute.POST("/batch_delete", controller.DeleteChannels) channelsRoute.GET("/test", controller.TestAllChannels) + + importRoute := channelsRoute.Group("/import") + { + importRoute.POST("/oneapi", controller.ImportChannelFromOneAPI) + } } channelRoute := apiRouter.Group("/channel") { @@ -152,6 +157,7 @@ func SetAPIRouter(router *gin.Engine) { monitorRoute.DELETE("/", controller.ClearAllModelErrors) monitorRoute.DELETE("/:id", controller.ClearChannelAllModelErrors) monitorRoute.DELETE("/:id/:model", controller.ClearChannelModelErrors) + monitorRoute.GET("/models", controller.GetModelsErrorRate) } } }