diff --git a/common/crypto.go b/common/crypto.go index 452284161..e17f4b01c 100644 --- a/common/crypto.go +++ b/common/crypto.go @@ -1,6 +1,23 @@ package common -import "golang.org/x/crypto/bcrypt" +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "golang.org/x/crypto/bcrypt" +) + +func GenerateHMACWithKey(key []byte, data string) string { + h := hmac.New(sha256.New, key) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) +} + +func GenerateHMAC(data string) string { + h := hmac.New(sha256.New, []byte(SessionSecret)) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) +} func Password2Hash(password string) (string, error) { passwordBytes := []byte(password) diff --git a/common/redis.go b/common/redis.go index cc8035af7..02582ee21 100644 --- a/common/redis.go +++ b/common/redis.go @@ -2,11 +2,15 @@ package common import ( "context" + "errors" "fmt" "os" + "reflect" + "strconv" "time" "github.com/go-redis/redis/v8" + "gorm.io/gorm" ) var RDB *redis.Client @@ -58,39 +62,167 @@ func RedisGet(key string) (string, error) { return RDB.Get(ctx, key).Result() } -func RedisExpire(key string, expiration time.Duration) error { +//func RedisExpire(key string, expiration time.Duration) error { +// ctx := context.Background() +// return RDB.Expire(ctx, key, expiration).Err() +//} +// +//func RedisGetEx(key string, expiration time.Duration) (string, error) { +// ctx := context.Background() +// return RDB.GetSet(ctx, key, expiration).Result() +//} + +func RedisDel(key string) error { ctx := context.Background() - return RDB.Expire(ctx, key, expiration).Err() + return RDB.Del(ctx, key).Err() } -func RedisGetEx(key string, expiration time.Duration) (string, error) { +func RedisHDelObj(key string) error { ctx := context.Background() - return RDB.GetSet(ctx, key, expiration).Result() + return RDB.HDel(ctx, key).Err() } -func RedisDel(key string) error { +func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error { ctx := context.Background() - return RDB.Del(ctx, key).Err() + + data := make(map[string]interface{}) + + // 使用反射遍历结构体字段 + v := reflect.ValueOf(obj).Elem() + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + value := v.Field(i) + + // Skip DeletedAt field + if field.Type.String() == "gorm.DeletedAt" { + continue + } + + // 处理指针类型 + if value.Kind() == reflect.Ptr { + if value.IsNil() { + data[field.Name] = "" + continue + } + value = value.Elem() + } + + // 处理布尔类型 + if value.Kind() == reflect.Bool { + data[field.Name] = strconv.FormatBool(value.Bool()) + continue + } + + // 其他类型直接转换为字符串 + data[field.Name] = fmt.Sprintf("%v", value.Interface()) + } + + txn := RDB.TxPipeline() + txn.HSet(ctx, key, data) + txn.Expire(ctx, key, expiration) + + _, err := txn.Exec(ctx) + if err != nil { + return fmt.Errorf("failed to execute transaction: %w", err) + } + return nil } -func RedisDecrease(key string, value int64) error { +func RedisHGetObj(key string, obj interface{}) error { + ctx := context.Background() + + result, err := RDB.HGetAll(ctx, key).Result() + if err != nil { + return fmt.Errorf("failed to load hash from Redis: %w", err) + } + + if len(result) == 0 { + return fmt.Errorf("key %s not found in Redis", key) + } + + // Handle both pointer and non-pointer values + val := reflect.ValueOf(obj) + if val.Kind() != reflect.Ptr { + return fmt.Errorf("obj must be a pointer to a struct, got %T", obj) + } + + v := val.Elem() + if v.Kind() != reflect.Struct { + return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface()) + } + + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + fieldName := field.Name + if value, ok := result[fieldName]; ok { + fieldValue := v.Field(i) + + // Handle pointer types + if fieldValue.Kind() == reflect.Ptr { + if value == "" { + continue + } + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fieldValue = fieldValue.Elem() + } + + // Enhanced type handling for Token struct + switch fieldValue.Kind() { + case reflect.String: + fieldValue.SetString(value) + case reflect.Int, reflect.Int64: + intValue, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse int field %s: %w", fieldName, err) + } + fieldValue.SetInt(intValue) + case reflect.Bool: + boolValue, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err) + } + fieldValue.SetBool(boolValue) + case reflect.Struct: + // Special handling for gorm.DeletedAt + if fieldValue.Type().String() == "gorm.DeletedAt" { + if value != "" { + timeValue, err := time.Parse(time.RFC3339, value) + if err != nil { + return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err) + } + fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true})) + } + } + default: + return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName) + } + } + } + + return nil +} +// RedisIncr Add this function to handle atomic increments +func RedisIncr(key string, delta int64) error { // 检查键的剩余生存时间 ttlCmd := RDB.TTL(context.Background(), key) ttl, err := ttlCmd.Result() - if err != nil { - // 失败则尝试直接减少 - return RDB.DecrBy(context.Background(), key, value).Err() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("failed to get TTL: %w", err) } - // 如果剩余生存时间大于0,则进行减少操作 + // 只有在 key 存在且有 TTL 时才需要特殊处理 if ttl > 0 { ctx := context.Background() // 开始一个Redis事务 txn := RDB.TxPipeline() // 减少余额 - decrCmd := txn.DecrBy(ctx, key, value) + decrCmd := txn.IncrBy(ctx, key, delta) if err := decrCmd.Err(); err != nil { return err // 如果减少失败,则直接返回错误 } @@ -101,26 +233,54 @@ func RedisDecrease(key string, value int64) error { // 执行事务 _, err = txn.Exec(ctx) return err - } else { - _ = RedisDel(key) } return nil } -// RedisIncr Add this function to handle atomic increments -func RedisIncr(key string, delta int) error { - ctx := context.Background() +func RedisHIncrBy(key, field string, delta int64) error { + ttlCmd := RDB.TTL(context.Background(), key) + ttl, err := ttlCmd.Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("failed to get TTL: %w", err) + } - // 检查键是否存在 - exists, err := RDB.Exists(ctx, key).Result() - if err != nil { + if ttl > 0 { + ctx := context.Background() + txn := RDB.TxPipeline() + + incrCmd := txn.HIncrBy(ctx, key, field, delta) + if err := incrCmd.Err(); err != nil { + return err + } + + txn.Expire(ctx, key, ttl) + + _, err = txn.Exec(ctx) return err } - if exists == 0 { - return fmt.Errorf("key does not exist") // 键不存在,返回错误 + return nil +} + +func RedisHSetField(key, field string, value interface{}) error { + ttlCmd := RDB.TTL(context.Background(), key) + ttl, err := ttlCmd.Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("failed to get TTL: %w", err) } - // 键存在,执行INCRBY操作 - result := RDB.IncrBy(ctx, key, int64(delta)) - return result.Err() + if ttl > 0 { + ctx := context.Background() + txn := RDB.TxPipeline() + + hsetCmd := txn.HSet(ctx, key, field, value) + if err := hsetCmd.Err(); err != nil { + return err + } + + txn.Expire(ctx, key, ttl) + + _, err = txn.Exec(ctx) + return err + } + return nil } diff --git a/constant/cache_key.go b/constant/cache_key.go index d5a2c5aca..27cb3b755 100644 --- a/constant/cache_key.go +++ b/constant/cache_key.go @@ -9,10 +9,15 @@ var ( UserId2StatusCacheSeconds = common.SyncFrequency ) +// Cache keys const ( - // Cache keys UserGroupKeyFmt = "user_group:%d" UserQuotaKeyFmt = "user_quota:%d" UserEnabledKeyFmt = "user_enabled:%d" UserUsernameKeyFmt = "user_name:%d" ) + +const ( + TokenFiledRemainQuota = "RemainQuota" + TokenFieldGroup = "Group" +) diff --git a/main.go b/main.go index 763a953e0..f6fb8cfc3 100644 --- a/main.go +++ b/main.go @@ -80,9 +80,6 @@ func main() { common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) model.InitChannelCache() } - if common.RedisEnabled { - go model.SyncTokenCache(common.SyncFrequency) - } if common.MemoryCacheEnabled { go model.SyncOptions(common.SyncFrequency) go model.SyncChannelCache(common.SyncFrequency) diff --git a/model/cache.go b/model/cache.go index 0d87d1e11..b6102200f 100644 --- a/model/cache.go +++ b/model/cache.go @@ -1,99 +1,16 @@ package model import ( - "encoding/json" "errors" "fmt" "math/rand" "one-api/common" - "one-api/constant" "sort" "strings" "sync" "time" ) -// 仅用于定时同步缓存 -var token2UserId = make(map[string]int) -var token2UserIdLock sync.RWMutex - -func cacheSetToken(token *Token) error { - jsonBytes, err := json.Marshal(token) - if err != nil { - return err - } - err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(constant.TokenCacheSeconds)*time.Second) - if err != nil { - common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error())) - return err - } - token2UserIdLock.Lock() - defer token2UserIdLock.Unlock() - token2UserId[token.Key] = token.UserId - return nil -} - -// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取 -func CacheGetTokenByKey(key string) (*Token, error) { - if !common.RedisEnabled { - return GetTokenByKey(key) - } - var token *Token - tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) - if err != nil { - // 如果缓存中不存在,则从数据库中获取 - token, err = GetTokenByKey(key) - if err != nil { - return nil, err - } - err = cacheSetToken(token) - return token, nil - } - // 如果缓存中存在,则续期时间 - err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(constant.TokenCacheSeconds)*time.Second) - err = json.Unmarshal([]byte(tokenObjectString), &token) - return token, err -} - -func SyncTokenCache(frequency int) { - for { - time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing tokens from database") - token2UserIdLock.Lock() - // 从token2UserId中获取所有的key - var copyToken2UserId = make(map[string]int) - for s, i := range token2UserId { - copyToken2UserId[s] = i - } - token2UserId = make(map[string]int) - token2UserIdLock.Unlock() - - for key := range copyToken2UserId { - token, err := GetTokenByKey(key) - if err != nil { - // 如果数据库中不存在,则删除缓存 - common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error())) - //delete redis - err := common.RedisDel(fmt.Sprintf("token:%s", key)) - if err != nil { - common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error())) - } - } else { - // 如果数据库中存在,先检查redis - _, err = common.RedisGet(fmt.Sprintf("token:%s", key)) - if err != nil { - // 如果redis中不存在,则跳过 - continue - } - err = cacheSetToken(token) - if err != nil { - common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error())) - } - } - } - } -} - //func CacheGetUserGroup(id int) (group string, err error) { // if !common.RedisEnabled { // return GetUserGroup(id) diff --git a/model/log.go b/model/log.go index 06abeafa1..d172f057c 100644 --- a/model/log.go +++ b/model/log.go @@ -12,16 +12,6 @@ import ( "gorm.io/gorm" ) -var groupCol string - -func init() { - if common.UsingPostgreSQL { - groupCol = `"group"` - } else { - groupCol = "`group`" - } -} - type Log struct { Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"` UserId int `json:"user_id" gorm:"index"` diff --git a/model/main.go b/model/main.go index 6fe4e8113..5606176b5 100644 --- a/model/main.go +++ b/model/main.go @@ -13,6 +13,20 @@ import ( "time" ) +var groupCol string +var keyCol string + +func init() { + if common.UsingPostgreSQL { + groupCol = `"group"` + keyCol = `"key"` + + } else { + groupCol = "`group`" + keyCol = "`key`" + } +} + var DB *gorm.DB var LOG_DB *gorm.DB diff --git a/model/token.go b/model/token.go index 4d52bf03a..1596a8ddd 100644 --- a/model/token.go +++ b/model/token.go @@ -3,6 +3,7 @@ package model import ( "errors" "fmt" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" "one-api/common" relaycommon "one-api/relay/common" @@ -30,6 +31,10 @@ type Token struct { DeletedAt gorm.DeletedAt `gorm:"index"` } +func (token *Token) Clean() { + token.Key = "" +} + func (token *Token) GetIpLimitsMap() map[string]any { // delete empty spaces //split with \n @@ -71,7 +76,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if key == "" { return nil, errors.New("未提供令牌") } - token, err = CacheGetTokenByKey(key) + token, err = GetTokenByKey(key, false) if err == nil { if token.Status == common.TokenStatusExhausted { keyPrefix := key[:3] @@ -129,21 +134,37 @@ func GetTokenById(id int) (*Token, error) { var err error = nil err = DB.First(&token, "id = ?", id).Error if err != nil { - if common.RedisEnabled { - go cacheSetToken(&token) - } + gopool.Go(func() { + if err := cacheSetToken(token); err != nil { + common.SysError("failed to update user status cache: " + err.Error()) + } + }) } return &token, err } -func GetTokenByKey(key string) (*Token, error) { - keyCol := "`key`" - if common.UsingPostgreSQL { - keyCol = `"key"` +func GetTokenByKey(key string, fromDB bool) (token *Token, err error) { + defer func() { + // Update Redis cache asynchronously on successful DB read + if shouldUpdateRedis(fromDB, err) && token != nil { + gopool.Go(func() { + if err := cacheSetToken(*token); err != nil { + common.SysError("failed to update user status cache: " + err.Error()) + } + }) + } + }() + if !fromDB && common.RedisEnabled { + // Try Redis first + token, err := cacheGetTokenByKey(key) + if err == nil { + return token, nil + } + // Don't return error - fall through to DB } - var token Token - err := DB.Where(keyCol+" = ?", key).First(&token).Error - return &token, err + fromDB = true + err = DB.Where(keyCol+" = ?", key).First(&token).Error + return token, err } func (token *Token) Insert() error { @@ -153,20 +174,48 @@ func (token *Token) Insert() error { } // Update Make sure your token's fields is completed, because this will update non-zero values -func (token *Token) Update() error { - var err error +func (token *Token) Update() (err error) { + defer func() { + if common.RedisEnabled && err == nil { + gopool.Go(func() { + err := cacheSetToken(*token) + if err != nil { + common.SysError("failed to update token cache: " + err.Error()) + } + }) + } + }() err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error return err } -func (token *Token) SelectUpdate() error { +func (token *Token) SelectUpdate() (err error) { + defer func() { + if common.RedisEnabled && err == nil { + gopool.Go(func() { + err := cacheSetToken(*token) + if err != nil { + common.SysError("failed to update token cache: " + err.Error()) + } + }) + } + }() // This can update zero values return DB.Model(token).Select("accessed_time", "status").Updates(token).Error } -func (token *Token) Delete() error { - var err error +func (token *Token) Delete() (err error) { + defer func() { + if common.RedisEnabled && err == nil { + gopool.Go(func() { + err := cacheDeleteToken(token.Key) + if err != nil { + common.SysError("failed to delete token cache: " + err.Error()) + } + }) + } + }() err = DB.Delete(token).Error return err } @@ -214,10 +263,16 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func IncreaseTokenQuota(id int, quota int) (err error) { +func IncreaseTokenQuota(id int, key string, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + gopool.Go(func() { + err := cacheIncrTokenQuota(key, int64(quota)) + if err != nil { + common.SysError("failed to increase token quota: " + err.Error()) + } + }) if common.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, quota) return nil @@ -236,10 +291,16 @@ func increaseTokenQuota(id int, quota int) (err error) { return err } -func DecreaseTokenQuota(id int, quota int) (err error) { +func DecreaseTokenQuota(id int, key string, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + gopool.Go(func() { + err := cacheDecrTokenQuota(key, int64(quota)) + if err != nil { + common.SysError("failed to decrease token quota: " + err.Error()) + } + }) if common.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) return nil @@ -262,20 +323,22 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { if quota < 0 { return errors.New("quota 不能为负数!") } - if !relayInfo.IsPlayground { - token, err := GetTokenById(relayInfo.TokenId) - if err != nil { - return err - } - if !token.UnlimitedQuota && token.RemainQuota < quota { - return errors.New("令牌额度不足") - } + if relayInfo.IsPlayground { + return nil } - if !relayInfo.IsPlayground { - err := DecreaseTokenQuota(relayInfo.TokenId, quota) - if err != nil { - return err - } + //if relayInfo.TokenUnlimited { + // return nil + //} + token, err := GetTokenById(relayInfo.TokenId) + if err != nil { + return err + } + if !relayInfo.TokenUnlimited && token.RemainQuota < quota { + return errors.New("令牌额度不足") + } + err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) + if err != nil { + return err } return nil } @@ -293,9 +356,9 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int if !relayInfo.IsPlayground { if quota > 0 { - err = DecreaseTokenQuota(relayInfo.TokenId, quota) + err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) } else { - err = IncreaseTokenQuota(relayInfo.TokenId, -quota) + err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota) } if err != nil { return err diff --git a/model/token_cache.go b/model/token_cache.go new file mode 100644 index 000000000..99b762f51 --- /dev/null +++ b/model/token_cache.go @@ -0,0 +1,64 @@ +package model + +import ( + "fmt" + "one-api/common" + "one-api/constant" + "time" +) + +func cacheSetToken(token Token) error { + key := common.GenerateHMAC(token.Key) + token.Clean() + err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second) + if err != nil { + return err + } + return nil +} + +func cacheDeleteToken(key string) error { + key = common.GenerateHMAC(key) + err := common.RedisHDelObj(fmt.Sprintf("token:%s", key)) + if err != nil { + return err + } + return nil +} + +func cacheIncrTokenQuota(key string, increment int64) error { + key = common.GenerateHMAC(key) + err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment) + if err != nil { + return err + } + return nil +} + +func cacheDecrTokenQuota(key string, decrement int64) error { + return cacheIncrTokenQuota(key, -decrement) +} + +func cacheSetTokenField(key string, field string, value string) error { + key = common.GenerateHMAC(key) + err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value) + if err != nil { + return err + } + return nil +} + +// CacheGetTokenByKey 从缓存中获取 token,如果缓存中不存在,则从数据库中获取 +func cacheGetTokenByKey(key string) (*Token, error) { + hmacKey := common.GenerateHMAC(key) + if !common.RedisEnabled { + return nil, nil + } + var token Token + err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token) + if err != nil { + return nil, err + } + token.Key = key + return &token, nil +} diff --git a/model/user.go b/model/user.go index 2f06bd7e9..0171f3b65 100644 --- a/model/user.go +++ b/model/user.go @@ -252,7 +252,7 @@ func (user *User) Update(updatePassword bool) error { } // 更新缓存 - return updateUserCache(user) + return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) } func (user *User) Edit(updatePassword bool) error { @@ -281,7 +281,7 @@ func (user *User) Edit(updatePassword bool) error { } // 更新缓存 - return updateUserCache(user) + return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) } func (user *User) Delete() error { @@ -411,7 +411,7 @@ func IsAdmin(userId int) bool { func IsUserEnabled(id int, fromDB bool) (status bool, err error) { defer func() { // Update Redis cache asynchronously on successful DB read - if common.RedisEnabled { + if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserStatusCache(id, status); err != nil { common.SysError("failed to update user status cache: " + err.Error()) @@ -427,7 +427,7 @@ func IsUserEnabled(id int, fromDB bool) (status bool, err error) { } // Don't return error - fall through to DB } - + fromDB = true var user User err = DB.Where("id = ?", id).Select("status").Find(&user).Error if err != nil { @@ -453,7 +453,7 @@ func ValidateAccessToken(token string) (user *User) { func GetUserQuota(id int, fromDB bool) (quota int, err error) { defer func() { // Update Redis cache asynchronously on successful DB read - if common.RedisEnabled && err == nil { + if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserQuotaCache(id, quota); err != nil { common.SysError("failed to update user quota cache: " + err.Error()) @@ -469,7 +469,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { // Don't return error - fall through to DB //common.SysError("failed to get user quota from cache: " + err.Error()) } - + fromDB = true err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error if err != nil { return 0, err @@ -492,7 +492,7 @@ func GetUserEmail(id int) (email string, err error) { func GetUserGroup(id int, fromDB bool) (group string, err error) { defer func() { // Update Redis cache asynchronously on successful DB read - if common.RedisEnabled && err == nil { + if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserGroupCache(id, group); err != nil { common.SysError("failed to update user group cache: " + err.Error()) @@ -507,7 +507,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { } // Don't return error - fall through to DB } - + fromDB = true err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error if err != nil { return "", err @@ -521,7 +521,7 @@ func IncreaseUserQuota(id int, quota int) (err error) { return errors.New("quota 不能为负数!") } gopool.Go(func() { - err := cacheIncrUserQuota(id, quota) + err := cacheIncrUserQuota(id, int64(quota)) if err != nil { common.SysError("failed to increase user quota: " + err.Error()) } @@ -546,7 +546,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { return errors.New("quota 不能为负数!") } gopool.Go(func() { - err := cacheDecrUserQuota(id, quota) + err := cacheDecrUserQuota(id, int64(quota)) if err != nil { common.SysError("failed to decrease user quota: " + err.Error()) } @@ -631,7 +631,7 @@ func updateUserRequestCount(id int, count int) { func GetUsernameById(id int, fromDB bool) (username string, err error) { defer func() { // Update Redis cache asynchronously on successful DB read - if common.RedisEnabled && err == nil { + if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserNameCache(id, username); err != nil { common.SysError("failed to update user name cache: " + err.Error()) @@ -646,7 +646,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) { } // Don't return error - fall through to DB } - + fromDB = true err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error if err != nil { return "", err diff --git a/model/user_cache.go b/model/user_cache.go index 8c1129395..9dc7e899e 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -93,24 +93,24 @@ func updateUserNameCache(userId int, username string) error { } // updateUserCache updates all user cache fields -func updateUserCache(user *User) error { +func updateUserCache(userId int, username string, userGroup string, quota int, status int) error { if !common.RedisEnabled { return nil } - if err := updateUserGroupCache(user.Id, user.Group); err != nil { + if err := updateUserGroupCache(userId, userGroup); err != nil { return fmt.Errorf("update group cache: %w", err) } - if err := updateUserQuotaCache(user.Id, user.Quota); err != nil { + if err := updateUserQuotaCache(userId, quota); err != nil { return fmt.Errorf("update quota cache: %w", err) } - if err := updateUserStatusCache(user.Id, user.Status == common.UserStatusEnabled); err != nil { + if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil { return fmt.Errorf("update status cache: %w", err) } - if err := updateUserNameCache(user.Id, user.Username); err != nil { + if err := updateUserNameCache(userId, username); err != nil { return fmt.Errorf("update username cache: %w", err) } @@ -193,7 +193,7 @@ func getUserCache(userId int) (*userCache, error) { } // Add atomic quota operations -func cacheIncrUserQuota(userId int, delta int) error { +func cacheIncrUserQuota(userId int, delta int64) error { if !common.RedisEnabled { return nil } @@ -201,6 +201,6 @@ func cacheIncrUserQuota(userId int, delta int) error { return common.RedisIncr(key, delta) } -func cacheDecrUserQuota(userId int, delta int) error { +func cacheDecrUserQuota(userId int, delta int64) error { return cacheIncrUserQuota(userId, -delta) } diff --git a/model/utils.go b/model/utils.go index 3905e9511..e6b09aa5b 100644 --- a/model/utils.go +++ b/model/utils.go @@ -88,3 +88,7 @@ func RecordExist(err error) (bool, error) { } return false, err } + +func shouldUpdateRedis(fromDB bool, err error) bool { + return common.RedisEnabled && fromDB && err == nil +} diff --git a/relay/relay-audio.go b/relay/relay-audio.go index a29434576..4c23a8f8d 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -81,19 +81,9 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } - if userQuota-preConsumedQuota < 0 { - return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } + preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr } defer func() { if openaiErr != nil { diff --git a/relay/relay-text.go b/relay/relay-text.go index 6f251f6d7..960dfd07a 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -291,14 +291,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo } if preConsumedQuota > 0 { - err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) - if err != nil { - return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } + err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) + if err != nil { + return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } } return preConsumedQuota, userQuota, nil } diff --git a/service/quota.go b/service/quota.go index 820dcce5a..19c7c0579 100644 --- a/service/quota.go +++ b/service/quota.go @@ -23,7 +23,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag return err } - token, err := model.CacheGetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-")) + token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false) if err != nil { return err }