diff --git a/service/aiproxy/common/network/ip.go b/service/aiproxy/common/network/ip.go index dcbe49bed0d..d1207a0f91c 100644 --- a/service/aiproxy/common/network/ip.go +++ b/service/aiproxy/common/network/ip.go @@ -3,17 +3,8 @@ package network import ( "fmt" "net" - "strings" ) -func splitSubnets(subnets string) []string { - res := strings.Split(subnets, ",") - for i := 0; i < len(res); i++ { - res[i] = strings.TrimSpace(res[i]) - } - return res -} - func isValidSubnet(subnet string) error { _, _, err := net.ParseCIDR(subnet) if err != nil { @@ -30,8 +21,8 @@ func isIPInSubnet(ip string, subnet string) (bool, error) { return ipNet.Contains(net.ParseIP(ip)), nil } -func IsValidSubnets(subnets string) error { - for _, subnet := range splitSubnets(subnets) { +func IsValidSubnets(subnets []string) error { + for _, subnet := range subnets { if err := isValidSubnet(subnet); err != nil { return err } @@ -39,8 +30,8 @@ func IsValidSubnets(subnets string) error { return nil } -func IsIPInSubnets(ip string, subnets string) (bool, error) { - for _, subnet := range splitSubnets(subnets) { +func IsIPInSubnets(ip string, subnets []string) (bool, error) { + for _, subnet := range subnets { if ok, err := isIPInSubnet(ip, subnet); err != nil { return false, err } else if ok { diff --git a/service/aiproxy/controller/group.go b/service/aiproxy/controller/group.go index bc0d0f16ed8..8442e30fb35 100644 --- a/service/aiproxy/controller/group.go +++ b/service/aiproxy/controller/group.go @@ -274,3 +274,28 @@ func CreateGroup(c *gin.Context) { } middleware.SuccessResponse(c, nil) } + +func UpdateGroup(c *gin.Context) { + group := c.Param("group") + if group == "" { + middleware.ErrorResponse(c, http.StatusOK, "invalid parameter") + return + } + req := CreateGroupRequest{} + err := json.NewDecoder(c.Request.Body).Decode(&req) + if err != nil { + middleware.ErrorResponse(c, http.StatusOK, "invalid parameter") + return + } + err = model.UpdateGroup(group, &model.Group{ + RPMRatio: req.RPMRatio, + RPM: req.RPM, + TPMRatio: req.TPMRatio, + TPM: req.TPM, + }) + if err != nil { + middleware.ErrorResponse(c, http.StatusOK, err.Error()) + return + } + middleware.SuccessResponse(c, nil) +} diff --git a/service/aiproxy/controller/token.go b/service/aiproxy/controller/token.go index 19aaf4b15f3..a3dbdc77379 100644 --- a/service/aiproxy/controller/token.go +++ b/service/aiproxy/controller/token.go @@ -39,7 +39,7 @@ func (t *TokenResponse) MarshalJSON() ([]byte, error) { type ( AddTokenRequest struct { Name string `json:"name"` - Subnet string `json:"subnet"` + Subnets []string `json:"subnets"` Models []string `json:"models"` ExpiredAt int64 `json:"expiredAt"` Quota float64 `json:"quota"` @@ -54,6 +54,20 @@ type ( } ) +func (at *AddTokenRequest) ToToken() *model.Token { + var expiredAt time.Time + if at.ExpiredAt > 0 { + expiredAt = time.UnixMilli(at.ExpiredAt) + } + return &model.Token{ + Name: model.EmptyNullString(at.Name), + Subnets: at.Subnets, + Models: at.Models, + ExpiredAt: expiredAt, + Quota: at.Quota, + } +} + func validateToken(token AddTokenRequest) error { if token.Name == "" { return errors.New("token name cannot be empty") @@ -61,20 +75,15 @@ func validateToken(token AddTokenRequest) error { if len(token.Name) > 30 { return errors.New("token name is too long") } - if token.Subnet != "" { - if err := network.IsValidSubnets(token.Subnet); err != nil { - return fmt.Errorf("invalid subnet: %w", err) - } + if err := network.IsValidSubnets(token.Subnets); err != nil { + return fmt.Errorf("invalid subnet: %w", err) } return nil } -func validateTokenStatus(token *model.Token) error { - if token.Status == model.TokenStatusExpired && !token.ExpiredAt.IsZero() && token.ExpiredAt.Before(time.Now()) { - return errors.New("token expired, please update token expired time or set to never expire") - } - if token.Status == model.TokenStatusExhausted && token.Quota > 0 && token.UsedAmount >= token.Quota { - return errors.New("token quota exhausted, please update token quota or set to unlimited quota") +func validateTokenUpdate(token AddTokenRequest) error { + if err := network.IsValidSubnets(token.Subnets); err != nil { + return fmt.Errorf("invalid subnet: %w", err) } return nil } @@ -223,7 +232,7 @@ func GetGroupToken(c *gin.Context) { middleware.SuccessResponse(c, buildTokenResponse(token)) } -func AddToken(c *gin.Context) { +func AddGroupToken(c *gin.Context) { group := c.Param("group") var req AddTokenRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -236,20 +245,9 @@ func AddToken(c *gin.Context) { return } - var expiredAt time.Time - if req.ExpiredAt > 0 { - expiredAt = time.UnixMilli(req.ExpiredAt) - } - - token := &model.Token{ - GroupID: group, - Name: model.EmptyNullString(req.Name), - Key: random.GenerateKey(), - ExpiredAt: expiredAt, - Quota: req.Quota, - Models: req.Models, - Subnet: req.Subnet, - } + token := req.ToToken() + token.GroupID = group + token.Key = random.GenerateKey() if err := model.InsertToken(token, c.Query("auto_create_group") == "true"); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) @@ -336,29 +334,14 @@ func UpdateToken(c *gin.Context) { return } - if err := validateToken(req); err != nil { + if err := validateTokenUpdate(req); err != nil { middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error()) return } - token, err := model.GetTokenByID(id) - if err != nil { - middleware.ErrorResponse(c, http.StatusOK, err.Error()) - return - } - - var expiredAt time.Time - if req.ExpiredAt > 0 { - expiredAt = time.UnixMilli(req.ExpiredAt) - } - - token.Name = model.EmptyNullString(req.Name) - token.ExpiredAt = expiredAt - token.Quota = req.Quota - token.Models = req.Models - token.Subnet = req.Subnet + token := req.ToToken() - if err := model.UpdateToken(token); err != nil { + if err := model.UpdateToken(id, token); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } @@ -380,29 +363,14 @@ func UpdateGroupToken(c *gin.Context) { return } - if err := validateToken(req); err != nil { + if err := validateTokenUpdate(req); err != nil { middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error()) return } - token, err := model.GetGroupTokenByID(group, id) - if err != nil { - middleware.ErrorResponse(c, http.StatusOK, err.Error()) - return - } - - var expiredAt time.Time - if req.ExpiredAt > 0 { - expiredAt = time.UnixMilli(req.ExpiredAt) - } - - token.Name = model.EmptyNullString(req.Name) - token.ExpiredAt = expiredAt - token.Quota = req.Quota - token.Models = req.Models - token.Subnet = req.Subnet + token := req.ToToken() - if err := model.UpdateToken(token); err != nil { + if err := model.UpdateGroupToken(id, group, token); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return } @@ -423,19 +391,6 @@ func UpdateTokenStatus(c *gin.Context) { return } - token, err := model.GetTokenByID(id) - if err != nil { - middleware.ErrorResponse(c, http.StatusOK, err.Error()) - return - } - - if req.Status == model.TokenStatusEnabled { - if err := validateTokenStatus(token); err != nil { - middleware.ErrorResponse(c, http.StatusOK, err.Error()) - return - } - } - if err := model.UpdateTokenStatus(id, req.Status); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return @@ -458,19 +413,6 @@ func UpdateGroupTokenStatus(c *gin.Context) { return } - token, err := model.GetGroupTokenByID(group, id) - if err != nil { - middleware.ErrorResponse(c, http.StatusOK, err.Error()) - return - } - - if req.Status == model.TokenStatusEnabled { - if err := validateTokenStatus(token); err != nil { - middleware.ErrorResponse(c, http.StatusOK, err.Error()) - return - } - } - if err := model.UpdateGroupTokenStatus(group, id, req.Status); err != nil { middleware.ErrorResponse(c, http.StatusOK, err.Error()) return diff --git a/service/aiproxy/middleware/auth.go b/service/aiproxy/middleware/auth.go index 4d00b34133b..a9b09e321b4 100644 --- a/service/aiproxy/middleware/auth.go +++ b/service/aiproxy/middleware/auth.go @@ -75,21 +75,19 @@ func TokenAuth(c *gin.Context) { SetLogTokenFields(log.Data, token, useInternalToken) - if token.Subnet != "" { - if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnet); err != nil { - abortLogWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } else if !ok { - abortLogWithMessage(c, http.StatusForbidden, - fmt.Sprintf("token (%s[%d]) can only be used in the specified subnet: %s, current ip: %s", - token.Name, - token.ID, - token.Subnet, - c.ClientIP(), - ), - ) - return - } + if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnets); err != nil { + abortLogWithMessage(c, http.StatusInternalServerError, err.Error()) + return + } else if !ok { + abortLogWithMessage(c, http.StatusForbidden, + fmt.Sprintf("token (%s[%d]) can only be used in the specified subnets: %v, current ip: %s", + token.Name, + token.ID, + token.Subnets, + c.ClientIP(), + ), + ) + return } var group *model.GroupCache diff --git a/service/aiproxy/model/cache.go b/service/aiproxy/model/cache.go index 927509c31d5..2f6d4767a4e 100644 --- a/service/aiproxy/model/cache.go +++ b/service/aiproxy/model/cache.go @@ -63,7 +63,7 @@ type TokenCache struct { Group string `json:"group" redis:"g"` Key string `json:"-" redis:"-"` Name string `json:"name" redis:"n"` - Subnet string `json:"subnet" redis:"s"` + Subnets redisStringSlice `json:"subnets" redis:"s"` Models redisStringSlice `json:"models" redis:"m"` ID int `json:"id" redis:"i"` Status int `json:"status" redis:"st"` @@ -78,7 +78,7 @@ func (t *Token) ToTokenCache() *TokenCache { Key: t.Key, Name: t.Name.String(), Models: t.Models, - Subnet: t.Subnet, + Subnets: t.Subnets, Status: t.Status, ExpiredAt: redisTime(t.ExpiredAt), Quota: t.Quota, diff --git a/service/aiproxy/model/group.go b/service/aiproxy/model/group.go index 63f93369cf2..f6922381b3d 100644 --- a/service/aiproxy/model/group.go +++ b/service/aiproxy/model/group.go @@ -131,6 +131,25 @@ func DeleteGroupsByIDs(ids []string) (err error) { }) } +func UpdateGroup(id string, group *Group) (err error) { + if id == "" { + return errors.New("group id is empty") + } + defer func() { + if err == nil { + if err := CacheDeleteGroup(id); err != nil { + log.Error("cache delete group failed: " + err.Error()) + } + } + }() + result := DB. + Clauses(clause.Returning{}). + Where("id = ?", id). + Select("rpm_ratio", "rpm", "tpm_ratio", "tpm"). + Updates(group) + return HandleUpdateResult(result, ErrGroupNotFound) +} + func UpdateGroupUsedAmountAndRequestCount(id string, amount float64, count int) (err error) { group := &Group{ID: id} defer func() { diff --git a/service/aiproxy/model/token.go b/service/aiproxy/model/token.go index 8ed234002e1..78f8ff7e3b0 100644 --- a/service/aiproxy/model/token.go +++ b/service/aiproxy/model/token.go @@ -18,10 +18,8 @@ const ( ) const ( - TokenStatusEnabled = 1 // don't use 0, 0 is the default value! - TokenStatusDisabled = 2 // also don't use 0 - TokenStatusExpired = 3 - TokenStatusExhausted = 4 + TokenStatusEnabled = 1 // don't use 0, 0 is the default value! + TokenStatusDisabled = 2 // also don't use 0 ) type Token struct { @@ -31,7 +29,7 @@ type Token struct { Key string `gorm:"type:char(48);uniqueIndex" json:"key"` Name EmptyNullString `gorm:"index;uniqueIndex:idx_group_name;not null" json:"name"` GroupID string `gorm:"index;uniqueIndex:idx_group_name" json:"group"` - Subnet string `json:"subnet"` + Subnets []string `gorm:"serializer:fastjson;type:text" json:"subnets"` Models []string `gorm:"serializer:fastjson;type:text" json:"models"` Status int `gorm:"default:1;index" json:"status"` ID int `gorm:"primaryKey" json:"id"` @@ -170,23 +168,14 @@ func SearchTokens(group string, keyword string, startIdx int, num int, order str } func GetTokenByKey(key string) (*Token, error) { + if key == "" { + return nil, errors.New("key is empty") + } var token Token err := DB.Where("key = ?", key).First(&token).Error return &token, HandleNotFound(err, ErrTokenNotFound) } -func GetTokenUsedAmount(id int) (float64, error) { - var amount float64 - err := DB.Model(&Token{}).Where("id = ?", id).Select("used_amount").Scan(&amount).Error - return amount, HandleNotFound(err, ErrTokenNotFound) -} - -func GetTokenUsedAmountByKey(key string) (float64, error) { - var amount float64 - err := DB.Model(&Token{}).Where("key = ?", key).Select("used_amount").Scan(&amount).Error - return amount, HandleNotFound(err, ErrTokenNotFound) -} - func ValidateAndGetToken(key string) (token *TokenCache, err error) { if key == "" { return nil, errors.New("no token provided") @@ -199,30 +188,13 @@ func ValidateAndGetToken(key string) (token *TokenCache, err error) { log.Error("get token from cache failed: " + err.Error()) return nil, errors.New("token validation failed") } - switch token.Status { - case TokenStatusExhausted: - return nil, fmt.Errorf("token (%s[%d]) quota is exhausted", token.Name, token.ID) - case TokenStatusExpired: - return nil, fmt.Errorf("token (%s[%d]) is expired", token.Name, token.ID) - case TokenStatusDisabled: + if token.Status == TokenStatusDisabled { return nil, fmt.Errorf("token (%s[%d]) is disabled", token.Name, token.ID) } - if token.Status != TokenStatusEnabled { - return nil, fmt.Errorf("token (%s[%d]) is not available", token.Name, token.ID) - } if !time.Time(token.ExpiredAt).IsZero() && time.Time(token.ExpiredAt).Before(time.Now()) { - err := UpdateTokenStatus(token.ID, TokenStatusExpired) - if err != nil { - log.Error("failed to update token status" + err.Error()) - } return nil, fmt.Errorf("token (%s[%d]) is expired", token.Name, token.ID) } if token.Quota > 0 && token.UsedAmount >= token.Quota { - // in this case, we can make sure the token is exhausted - err := UpdateTokenStatus(token.ID, TokenStatusExhausted) - if err != nil { - log.Error("failed to update token status" + err.Error()) - } return nil, fmt.Errorf("token (%s[%d]) quota is exhausted", token.Name, token.ID) } return token, nil @@ -274,6 +246,9 @@ func UpdateTokenStatus(id int, status int) (err error) { } func UpdateGroupTokenStatus(group string, id int, status int) (err error) { + if id == 0 || group == "" { + return errors.New("id or group is empty") + } token := Token{} defer func() { if err == nil { @@ -300,7 +275,7 @@ func UpdateGroupTokenStatus(group string, id int, status int) (err error) { func DeleteGroupTokenByID(groupID string, id int) (err error) { if id == 0 || groupID == "" { - return errors.New("id 或 group 为空!") + return errors.New("id or group is empty") } token := Token{ID: id, GroupID: groupID} defer func() { @@ -321,7 +296,10 @@ func DeleteGroupTokenByID(groupID string, id int) (err error) { return HandleUpdateResult(result, ErrTokenNotFound) } -func DeleteGroupTokensByIDs(groupID string, ids []int) (err error) { +func DeleteGroupTokensByIDs(group string, ids []int) (err error) { + if group == "" { + return errors.New("group is empty") + } if len(ids) == 0 { return nil } @@ -342,7 +320,7 @@ func DeleteGroupTokensByIDs(groupID string, ids []int) (err error) { {Name: "key"}, }, }). - Where("group_id = ?", groupID). + Where("group_id = ?", group). Where("id IN (?)", ids). Delete(&tokens). Error @@ -351,7 +329,7 @@ func DeleteGroupTokensByIDs(groupID string, ids []int) (err error) { func DeleteTokenByID(id int) (err error) { if id == 0 { - return errors.New("id 为空!") + return errors.New("id is empty") } token := Token{ID: id} defer func() { @@ -399,7 +377,35 @@ func DeleteTokensByIDs(ids []int) (err error) { }) } -func UpdateToken(token *Token) (err error) { +func UpdateToken(id int, token *Token) (err error) { + if id == 0 { + return errors.New("id is empty") + } + defer func() { + if err == nil { + if err := CacheDeleteToken(token.Key); err != nil { + log.Error("delete token from cache failed: " + err.Error()) + } + } + }() + result := DB. + Select("subnets", "quota", "models", "expired_at"). + Where("id = ?", id). + Clauses(clause.Returning{}). + Updates(token) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrDuplicatedKey) { + return errors.New("token name already exists in this group") + } + } + return HandleUpdateResult(result, ErrTokenNotFound) +} + +func UpdateGroupToken(id int, group string, token *Token) (err error) { + if id == 0 || group == "" { + return errors.New("id or group is empty") + } + defer func() { if err == nil { if err := CacheDeleteToken(token.Key); err != nil { @@ -407,7 +413,11 @@ func UpdateToken(token *Token) (err error) { } } }() - result := DB.Omit("created_at", "status", "key", "group_id", "used_amount", "request_count").Save(token) + result := DB. + Select("subnets", "quota", "models", "expired_at"). + Where("id = ? and group_id = ?", id, group). + Clauses(clause.Returning{}). + Updates(token) if result.Error != nil { if errors.Is(result.Error, gorm.ErrDuplicatedKey) { return errors.New("token name already exists in this group") diff --git a/service/aiproxy/router/api.go b/service/aiproxy/router/api.go index 54a1b4dacf7..489de8028fc 100644 --- a/service/aiproxy/router/api.go +++ b/service/aiproxy/router/api.go @@ -49,6 +49,7 @@ func SetAPIRouter(router *gin.Engine) { groupRoute := apiRouter.Group("/group") { groupRoute.POST("/:group", controller.CreateGroup) + groupRoute.PUT("/:group", controller.UpdateGroup) groupRoute.GET("/:group", controller.GetGroup) groupRoute.DELETE("/:group", controller.DeleteGroup) groupRoute.POST("/:group/status", controller.UpdateGroupStatus) @@ -112,7 +113,7 @@ func SetAPIRouter(router *gin.Engine) { tokenRoute.POST("/:group/batch_delete", controller.DeleteGroupTokens) tokenRoute.GET("/:group", controller.GetGroupTokens) tokenRoute.GET("/:group/:id", controller.GetGroupToken) - tokenRoute.POST("/:group", controller.AddToken) + tokenRoute.POST("/:group", controller.AddGroupToken) tokenRoute.PUT("/:group/:id", controller.UpdateGroupToken) tokenRoute.POST("/:group/:id/status", controller.UpdateGroupTokenStatus) tokenRoute.POST("/:group/:id/name", controller.UpdateGroupTokenName)