Skip to content

Commit

Permalink
feat: token and group model update (#5419)
Browse files Browse the repository at this point in the history
* feat: update token model

* feat: update group
  • Loading branch information
zijiren233 authored Feb 27, 2025
1 parent 1ae85d8 commit c6f9542
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 160 deletions.
17 changes: 4 additions & 13 deletions service/aiproxy/common/network/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,8 @@ package network
import (
"fmt"
"net"
"strings"
)

func splitSubnets(subnets string) []string {
res := strings.Split(subnets, ",")
for i := 0; i < len(res); i++ {
res[i] = strings.TrimSpace(res[i])
}
return res
}

func isValidSubnet(subnet string) error {
_, _, err := net.ParseCIDR(subnet)
if err != nil {
Expand All @@ -30,17 +21,17 @@ func isIPInSubnet(ip string, subnet string) (bool, error) {
return ipNet.Contains(net.ParseIP(ip)), nil
}

func IsValidSubnets(subnets string) error {
for _, subnet := range splitSubnets(subnets) {
func IsValidSubnets(subnets []string) error {
for _, subnet := range subnets {
if err := isValidSubnet(subnet); err != nil {
return err
}
}
return nil
}

func IsIPInSubnets(ip string, subnets string) (bool, error) {
for _, subnet := range splitSubnets(subnets) {
func IsIPInSubnets(ip string, subnets []string) (bool, error) {
for _, subnet := range subnets {
if ok, err := isIPInSubnet(ip, subnet); err != nil {
return false, err
} else if ok {
Expand Down
25 changes: 25 additions & 0 deletions service/aiproxy/controller/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,28 @@ func CreateGroup(c *gin.Context) {
}
middleware.SuccessResponse(c, nil)
}

func UpdateGroup(c *gin.Context) {
group := c.Param("group")
if group == "" {
middleware.ErrorResponse(c, http.StatusOK, "invalid parameter")
return
}
req := CreateGroupRequest{}
err := json.NewDecoder(c.Request.Body).Decode(&req)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, "invalid parameter")
return
}
err = model.UpdateGroup(group, &model.Group{
RPMRatio: req.RPMRatio,
RPM: req.RPM,
TPMRatio: req.TPMRatio,
TPM: req.TPM,
})
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
middleware.SuccessResponse(c, nil)
}
118 changes: 30 additions & 88 deletions service/aiproxy/controller/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (t *TokenResponse) MarshalJSON() ([]byte, error) {
type (
AddTokenRequest struct {
Name string `json:"name"`
Subnet string `json:"subnet"`
Subnets []string `json:"subnets"`
Models []string `json:"models"`
ExpiredAt int64 `json:"expiredAt"`
Quota float64 `json:"quota"`
Expand All @@ -54,27 +54,36 @@ type (
}
)

func (at *AddTokenRequest) ToToken() *model.Token {
var expiredAt time.Time
if at.ExpiredAt > 0 {
expiredAt = time.UnixMilli(at.ExpiredAt)
}
return &model.Token{
Name: model.EmptyNullString(at.Name),
Subnets: at.Subnets,
Models: at.Models,
ExpiredAt: expiredAt,
Quota: at.Quota,
}
}

func validateToken(token AddTokenRequest) error {
if token.Name == "" {
return errors.New("token name cannot be empty")
}
if len(token.Name) > 30 {
return errors.New("token name is too long")
}
if token.Subnet != "" {
if err := network.IsValidSubnets(token.Subnet); err != nil {
return fmt.Errorf("invalid subnet: %w", err)
}
if err := network.IsValidSubnets(token.Subnets); err != nil {
return fmt.Errorf("invalid subnet: %w", err)
}
return nil
}

func validateTokenStatus(token *model.Token) error {
if token.Status == model.TokenStatusExpired && !token.ExpiredAt.IsZero() && token.ExpiredAt.Before(time.Now()) {
return errors.New("token expired, please update token expired time or set to never expire")
}
if token.Status == model.TokenStatusExhausted && token.Quota > 0 && token.UsedAmount >= token.Quota {
return errors.New("token quota exhausted, please update token quota or set to unlimited quota")
func validateTokenUpdate(token AddTokenRequest) error {
if err := network.IsValidSubnets(token.Subnets); err != nil {
return fmt.Errorf("invalid subnet: %w", err)
}
return nil
}
Expand Down Expand Up @@ -223,7 +232,7 @@ func GetGroupToken(c *gin.Context) {
middleware.SuccessResponse(c, buildTokenResponse(token))
}

func AddToken(c *gin.Context) {
func AddGroupToken(c *gin.Context) {
group := c.Param("group")
var req AddTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
Expand All @@ -236,20 +245,9 @@ func AddToken(c *gin.Context) {
return
}

var expiredAt time.Time
if req.ExpiredAt > 0 {
expiredAt = time.UnixMilli(req.ExpiredAt)
}

token := &model.Token{
GroupID: group,
Name: model.EmptyNullString(req.Name),
Key: random.GenerateKey(),
ExpiredAt: expiredAt,
Quota: req.Quota,
Models: req.Models,
Subnet: req.Subnet,
}
token := req.ToToken()
token.GroupID = group
token.Key = random.GenerateKey()

if err := model.InsertToken(token, c.Query("auto_create_group") == "true"); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
Expand Down Expand Up @@ -336,29 +334,14 @@ func UpdateToken(c *gin.Context) {
return
}

if err := validateToken(req); err != nil {
if err := validateTokenUpdate(req); err != nil {
middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error())
return
}

token, err := model.GetTokenByID(id)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}

var expiredAt time.Time
if req.ExpiredAt > 0 {
expiredAt = time.UnixMilli(req.ExpiredAt)
}

token.Name = model.EmptyNullString(req.Name)
token.ExpiredAt = expiredAt
token.Quota = req.Quota
token.Models = req.Models
token.Subnet = req.Subnet
token := req.ToToken()

if err := model.UpdateToken(token); err != nil {
if err := model.UpdateToken(id, token); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
Expand All @@ -380,29 +363,14 @@ func UpdateGroupToken(c *gin.Context) {
return
}

if err := validateToken(req); err != nil {
if err := validateTokenUpdate(req); err != nil {
middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error())
return
}

token, err := model.GetGroupTokenByID(group, id)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}

var expiredAt time.Time
if req.ExpiredAt > 0 {
expiredAt = time.UnixMilli(req.ExpiredAt)
}

token.Name = model.EmptyNullString(req.Name)
token.ExpiredAt = expiredAt
token.Quota = req.Quota
token.Models = req.Models
token.Subnet = req.Subnet
token := req.ToToken()

if err := model.UpdateToken(token); err != nil {
if err := model.UpdateGroupToken(id, group, token); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
Expand All @@ -423,19 +391,6 @@ func UpdateTokenStatus(c *gin.Context) {
return
}

token, err := model.GetTokenByID(id)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}

if req.Status == model.TokenStatusEnabled {
if err := validateTokenStatus(token); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
}

if err := model.UpdateTokenStatus(id, req.Status); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
Expand All @@ -458,19 +413,6 @@ func UpdateGroupTokenStatus(c *gin.Context) {
return
}

token, err := model.GetGroupTokenByID(group, id)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}

if req.Status == model.TokenStatusEnabled {
if err := validateTokenStatus(token); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
}

if err := model.UpdateGroupTokenStatus(group, id, req.Status); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
Expand Down
28 changes: 13 additions & 15 deletions service/aiproxy/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,19 @@ func TokenAuth(c *gin.Context) {

SetLogTokenFields(log.Data, token, useInternalToken)

if token.Subnet != "" {
if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnet); err != nil {
abortLogWithMessage(c, http.StatusInternalServerError, err.Error())
return
} else if !ok {
abortLogWithMessage(c, http.StatusForbidden,
fmt.Sprintf("token (%s[%d]) can only be used in the specified subnet: %s, current ip: %s",
token.Name,
token.ID,
token.Subnet,
c.ClientIP(),
),
)
return
}
if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnets); err != nil {
abortLogWithMessage(c, http.StatusInternalServerError, err.Error())
return
} else if !ok {
abortLogWithMessage(c, http.StatusForbidden,
fmt.Sprintf("token (%s[%d]) can only be used in the specified subnets: %v, current ip: %s",
token.Name,
token.ID,
token.Subnets,
c.ClientIP(),
),
)
return
}

var group *model.GroupCache
Expand Down
4 changes: 2 additions & 2 deletions service/aiproxy/model/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type TokenCache struct {
Group string `json:"group" redis:"g"`
Key string `json:"-" redis:"-"`
Name string `json:"name" redis:"n"`
Subnet string `json:"subnet" redis:"s"`
Subnets redisStringSlice `json:"subnets" redis:"s"`
Models redisStringSlice `json:"models" redis:"m"`
ID int `json:"id" redis:"i"`
Status int `json:"status" redis:"st"`
Expand All @@ -78,7 +78,7 @@ func (t *Token) ToTokenCache() *TokenCache {
Key: t.Key,
Name: t.Name.String(),
Models: t.Models,
Subnet: t.Subnet,
Subnets: t.Subnets,
Status: t.Status,
ExpiredAt: redisTime(t.ExpiredAt),
Quota: t.Quota,
Expand Down
19 changes: 19 additions & 0 deletions service/aiproxy/model/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,25 @@ func DeleteGroupsByIDs(ids []string) (err error) {
})
}

func UpdateGroup(id string, group *Group) (err error) {
if id == "" {
return errors.New("group id is empty")
}
defer func() {
if err == nil {
if err := CacheDeleteGroup(id); err != nil {
log.Error("cache delete group failed: " + err.Error())
}
}
}()
result := DB.
Clauses(clause.Returning{}).
Where("id = ?", id).
Select("rpm_ratio", "rpm", "tpm_ratio", "tpm").
Updates(group)
return HandleUpdateResult(result, ErrGroupNotFound)
}

func UpdateGroupUsedAmountAndRequestCount(id string, amount float64, count int) (err error) {
group := &Group{ID: id}
defer func() {
Expand Down
Loading

0 comments on commit c6f9542

Please sign in to comment.