diff --git a/common/user_groups.go b/common/user_groups.go new file mode 100644 index 000000000..67c3e7912 --- /dev/null +++ b/common/user_groups.go @@ -0,0 +1,23 @@ +package common + +import ( + "encoding/json" +) + +var UserUsableGroups = map[string]string{ + "default": "默认分组", + "vip": "vip分组", +} + +func UserUsableGroups2JSONString() string { + jsonBytes, err := json.Marshal(UserUsableGroups) + if err != nil { + SysError("error marshalling user groups: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateUserUsableGroupsByJSONString(jsonStr string) error { + UserUsableGroups = make(map[string]string) + return json.Unmarshal([]byte(jsonStr), &UserUsableGroups) +} diff --git a/controller/group.go b/controller/group.go index 2b2f6006f..2ee008b91 100644 --- a/controller/group.go +++ b/controller/group.go @@ -17,3 +17,18 @@ func GetGroups(c *gin.Context) { "data": groupNames, }) } + +func GetUserGroups(c *gin.Context) { + usableGroups := make(map[string]string) + for groupName, _ := range common.GroupRatio { + // UserUsableGroups contains the groups that the user can use + if _, ok := common.UserUsableGroups[groupName]; ok { + usableGroups[groupName] = common.UserUsableGroups[groupName] + } + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": usableGroups, + }) +} diff --git a/controller/token.go b/controller/token.go index 50a368f6f..0fc4b6cf3 100644 --- a/controller/token.go +++ b/controller/token.go @@ -135,6 +135,7 @@ func AddToken(c *gin.Context) { ModelLimitsEnabled: token.ModelLimitsEnabled, ModelLimits: token.ModelLimits, AllowIps: token.AllowIps, + Group: token.Group, } err = cleanToken.Insert() if err != nil { @@ -223,6 +224,7 @@ func UpdateToken(c *gin.Context) { cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled cleanToken.ModelLimits = token.ModelLimits cleanToken.AllowIps = token.AllowIps + cleanToken.Group = token.Group } err = cleanToken.Update() if err != nil { diff --git a/middleware/auth.go b/middleware/auth.go index 481960efa..8426b04ae 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -176,6 +176,7 @@ func TokenAuth() func(c *gin.Context) { c.Set("token_model_limit_enabled", false) } c.Set("allow_ips", token.GetIpLimitsMap()) + c.Set("token_group", token.Group) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("specific_channel_id", parts[1]) diff --git a/middleware/distributor.go b/middleware/distributor.go index 9b55cc2d2..0393d24f6 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -39,6 +39,15 @@ func Distribute() func(c *gin.Context) { return } userGroup, _ := model.CacheGetUserGroup(userId) + tokenGroup := c.GetString("token_group") + if tokenGroup != "" { + // check group in common.GroupRatio + if _, ok := common.GroupRatio[tokenGroup]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被禁用", tokenGroup)) + return + } + userGroup = tokenGroup + } c.Set("group", userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) diff --git a/model/option.go b/model/option.go index 4348919e4..04f952d4c 100644 --- a/model/option.go +++ b/model/option.go @@ -86,6 +86,7 @@ func InitOptionMap() { common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() + common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["ChatLink"] = common.ChatLink @@ -303,6 +304,8 @@ func updateOptionMap(key string, value string) (err error) { err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": err = common.UpdateGroupRatioByJSONString(value) + case "UserUsableGroups": + err = common.UpdateUserUsableGroupsByJSONString(value) case "CompletionRatio": err = common.UpdateCompletionRatioByJSONString(value) case "ModelPrice": diff --git a/model/token.go b/model/token.go index 18aa2979e..dc769c842 100644 --- a/model/token.go +++ b/model/token.go @@ -25,6 +25,7 @@ type Token struct { ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"` AllowIps *string `json:"allow_ips" gorm:"default:''"` UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota + Group string `json:"group" gorm:"default:''"` DeletedAt gorm.DeletedAt `gorm:"index"` } @@ -153,7 +154,8 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits", "allow_ips").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", + "model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error return err } diff --git a/router/api-router.go b/router/api-router.go index 68079396a..c38a3144f 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -39,6 +39,7 @@ func SetApiRouter(router *gin.Engine) { //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog) userRoute.GET("/logout", controller.Logout) userRoute.GET("/epay/notify", controller.EpayNotify) + userRoute.GET("/groups", controller.GetUserGroups) selfRoute := userRoute.Group("/") selfRoute.Use(middleware.UserAuth()) diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index a61b757f9..1d875c618 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -23,6 +23,7 @@ const OperationSetting = () => { CompletionRatio: '', ModelPrice: '', GroupRatio: '', + UserUsableGroups: '', TopUpLink: '', ChatLink: '', ChatLink2: '', // 添加的新状态变量 @@ -62,6 +63,7 @@ const OperationSetting = () => { if ( item.key === 'ModelRatio' || item.key === 'GroupRatio' || + item.key === 'UserUsableGroups' || item.key === 'CompletionRatio' || item.key === 'ModelPrice' ) { diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js index 64a189fd5..74b249ac7 100644 --- a/web/src/components/TokensTable.js +++ b/web/src/components/TokensTable.js @@ -8,14 +8,14 @@ import { } from '../helpers'; import { ITEMS_PER_PAGE } from '../constants'; -import { renderQuota } from '../helpers/render'; +import {renderGroup, renderQuota} from '../helpers/render'; import { Button, Dropdown, Form, Modal, Popconfirm, - Popover, + Popover, Space, SplitButtonGroup, Table, Tag, @@ -119,7 +119,12 @@ const TokensTable = () => { dataIndex: 'status', key: 'status', render: (text, record, index) => { - return