Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: token and group model update #5419

Merged
merged 2 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions service/aiproxy/common/network/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,8 @@ package network
import (
"fmt"
"net"
"strings"
)

func splitSubnets(subnets string) []string {
res := strings.Split(subnets, ",")
for i := 0; i < len(res); i++ {
res[i] = strings.TrimSpace(res[i])
}
return res
}

func isValidSubnet(subnet string) error {
_, _, err := net.ParseCIDR(subnet)
if err != nil {
Expand All @@ -30,17 +21,17 @@ func isIPInSubnet(ip string, subnet string) (bool, error) {
return ipNet.Contains(net.ParseIP(ip)), nil
}

func IsValidSubnets(subnets string) error {
for _, subnet := range splitSubnets(subnets) {
func IsValidSubnets(subnets []string) error {
for _, subnet := range subnets {
if err := isValidSubnet(subnet); err != nil {
return err
}
}
return nil
}

func IsIPInSubnets(ip string, subnets string) (bool, error) {
for _, subnet := range splitSubnets(subnets) {
func IsIPInSubnets(ip string, subnets []string) (bool, error) {
for _, subnet := range subnets {
if ok, err := isIPInSubnet(ip, subnet); err != nil {
return false, err
} else if ok {
Expand Down
25 changes: 25 additions & 0 deletions service/aiproxy/controller/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,28 @@ func CreateGroup(c *gin.Context) {
}
middleware.SuccessResponse(c, nil)
}

func UpdateGroup(c *gin.Context) {
group := c.Param("group")
if group == "" {
middleware.ErrorResponse(c, http.StatusOK, "invalid parameter")
return
}
req := CreateGroupRequest{}
err := json.NewDecoder(c.Request.Body).Decode(&req)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, "invalid parameter")
return
}
err = model.UpdateGroup(group, &model.Group{
RPMRatio: req.RPMRatio,
RPM: req.RPM,
TPMRatio: req.TPMRatio,
TPM: req.TPM,
})
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
middleware.SuccessResponse(c, nil)
}
118 changes: 30 additions & 88 deletions service/aiproxy/controller/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (t *TokenResponse) MarshalJSON() ([]byte, error) {
type (
AddTokenRequest struct {
Name string `json:"name"`
Subnet string `json:"subnet"`
Subnets []string `json:"subnets"`
Models []string `json:"models"`
ExpiredAt int64 `json:"expiredAt"`
Quota float64 `json:"quota"`
Expand All @@ -54,27 +54,36 @@ type (
}
)

func (at *AddTokenRequest) ToToken() *model.Token {
var expiredAt time.Time
if at.ExpiredAt > 0 {
expiredAt = time.UnixMilli(at.ExpiredAt)
}
return &model.Token{
Name: model.EmptyNullString(at.Name),
Subnets: at.Subnets,
Models: at.Models,
ExpiredAt: expiredAt,
Quota: at.Quota,
}
}

func validateToken(token AddTokenRequest) error {
if token.Name == "" {
return errors.New("token name cannot be empty")
}
if len(token.Name) > 30 {
return errors.New("token name is too long")
}
if token.Subnet != "" {
if err := network.IsValidSubnets(token.Subnet); err != nil {
return fmt.Errorf("invalid subnet: %w", err)
}
if err := network.IsValidSubnets(token.Subnets); err != nil {
return fmt.Errorf("invalid subnet: %w", err)
}
return nil
}

func validateTokenStatus(token *model.Token) error {
if token.Status == model.TokenStatusExpired && !token.ExpiredAt.IsZero() && token.ExpiredAt.Before(time.Now()) {
return errors.New("token expired, please update token expired time or set to never expire")
}
if token.Status == model.TokenStatusExhausted && token.Quota > 0 && token.UsedAmount >= token.Quota {
return errors.New("token quota exhausted, please update token quota or set to unlimited quota")
func validateTokenUpdate(token AddTokenRequest) error {
if err := network.IsValidSubnets(token.Subnets); err != nil {
return fmt.Errorf("invalid subnet: %w", err)
}
return nil
}
Expand Down Expand Up @@ -223,7 +232,7 @@ func GetGroupToken(c *gin.Context) {
middleware.SuccessResponse(c, buildTokenResponse(token))
}

func AddToken(c *gin.Context) {
func AddGroupToken(c *gin.Context) {
group := c.Param("group")
var req AddTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
Expand All @@ -236,20 +245,9 @@ func AddToken(c *gin.Context) {
return
}

var expiredAt time.Time
if req.ExpiredAt > 0 {
expiredAt = time.UnixMilli(req.ExpiredAt)
}

token := &model.Token{
GroupID: group,
Name: model.EmptyNullString(req.Name),
Key: random.GenerateKey(),
ExpiredAt: expiredAt,
Quota: req.Quota,
Models: req.Models,
Subnet: req.Subnet,
}
token := req.ToToken()
token.GroupID = group
token.Key = random.GenerateKey()

if err := model.InsertToken(token, c.Query("auto_create_group") == "true"); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
Expand Down Expand Up @@ -336,29 +334,14 @@ func UpdateToken(c *gin.Context) {
return
}

if err := validateToken(req); err != nil {
if err := validateTokenUpdate(req); err != nil {
middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error())
return
}

token, err := model.GetTokenByID(id)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}

var expiredAt time.Time
if req.ExpiredAt > 0 {
expiredAt = time.UnixMilli(req.ExpiredAt)
}

token.Name = model.EmptyNullString(req.Name)
token.ExpiredAt = expiredAt
token.Quota = req.Quota
token.Models = req.Models
token.Subnet = req.Subnet
token := req.ToToken()

if err := model.UpdateToken(token); err != nil {
if err := model.UpdateToken(id, token); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
Expand All @@ -380,29 +363,14 @@ func UpdateGroupToken(c *gin.Context) {
return
}

if err := validateToken(req); err != nil {
if err := validateTokenUpdate(req); err != nil {
middleware.ErrorResponse(c, http.StatusOK, "parameter error: "+err.Error())
return
}

token, err := model.GetGroupTokenByID(group, id)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}

var expiredAt time.Time
if req.ExpiredAt > 0 {
expiredAt = time.UnixMilli(req.ExpiredAt)
}

token.Name = model.EmptyNullString(req.Name)
token.ExpiredAt = expiredAt
token.Quota = req.Quota
token.Models = req.Models
token.Subnet = req.Subnet
token := req.ToToken()

if err := model.UpdateToken(token); err != nil {
if err := model.UpdateGroupToken(id, group, token); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
Expand All @@ -423,19 +391,6 @@ func UpdateTokenStatus(c *gin.Context) {
return
}

token, err := model.GetTokenByID(id)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}

if req.Status == model.TokenStatusEnabled {
if err := validateTokenStatus(token); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
}

if err := model.UpdateTokenStatus(id, req.Status); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
Expand All @@ -458,19 +413,6 @@ func UpdateGroupTokenStatus(c *gin.Context) {
return
}

token, err := model.GetGroupTokenByID(group, id)
if err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}

if req.Status == model.TokenStatusEnabled {
if err := validateTokenStatus(token); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
}
}

if err := model.UpdateGroupTokenStatus(group, id, req.Status); err != nil {
middleware.ErrorResponse(c, http.StatusOK, err.Error())
return
Expand Down
28 changes: 13 additions & 15 deletions service/aiproxy/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,19 @@ func TokenAuth(c *gin.Context) {

SetLogTokenFields(log.Data, token, useInternalToken)

if token.Subnet != "" {
if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnet); err != nil {
abortLogWithMessage(c, http.StatusInternalServerError, err.Error())
return
} else if !ok {
abortLogWithMessage(c, http.StatusForbidden,
fmt.Sprintf("token (%s[%d]) can only be used in the specified subnet: %s, current ip: %s",
token.Name,
token.ID,
token.Subnet,
c.ClientIP(),
),
)
return
}
if ok, err := network.IsIPInSubnets(c.ClientIP(), token.Subnets); err != nil {
abortLogWithMessage(c, http.StatusInternalServerError, err.Error())
return
} else if !ok {
abortLogWithMessage(c, http.StatusForbidden,
fmt.Sprintf("token (%s[%d]) can only be used in the specified subnets: %v, current ip: %s",
token.Name,
token.ID,
token.Subnets,
c.ClientIP(),
),
)
return
}

var group *model.GroupCache
Expand Down
4 changes: 2 additions & 2 deletions service/aiproxy/model/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type TokenCache struct {
Group string `json:"group" redis:"g"`
Key string `json:"-" redis:"-"`
Name string `json:"name" redis:"n"`
Subnet string `json:"subnet" redis:"s"`
Subnets redisStringSlice `json:"subnets" redis:"s"`
Models redisStringSlice `json:"models" redis:"m"`
ID int `json:"id" redis:"i"`
Status int `json:"status" redis:"st"`
Expand All @@ -78,7 +78,7 @@ func (t *Token) ToTokenCache() *TokenCache {
Key: t.Key,
Name: t.Name.String(),
Models: t.Models,
Subnet: t.Subnet,
Subnets: t.Subnets,
Status: t.Status,
ExpiredAt: redisTime(t.ExpiredAt),
Quota: t.Quota,
Expand Down
19 changes: 19 additions & 0 deletions service/aiproxy/model/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,25 @@ func DeleteGroupsByIDs(ids []string) (err error) {
})
}

func UpdateGroup(id string, group *Group) (err error) {
if id == "" {
return errors.New("group id is empty")
}
defer func() {
if err == nil {
if err := CacheDeleteGroup(id); err != nil {
log.Error("cache delete group failed: " + err.Error())
}
}
}()
result := DB.
Clauses(clause.Returning{}).
Where("id = ?", id).
Select("rpm_ratio", "rpm", "tpm_ratio", "tpm").
Updates(group)
return HandleUpdateResult(result, ErrGroupNotFound)
}

func UpdateGroupUsedAmountAndRequestCount(id string, amount float64, count int) (err error) {
group := &Group{ID: id}
defer func() {
Expand Down
Loading