Skip to content

Commit

Permalink
refactor: token cache logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Dec 30, 2024
1 parent ca8b7ed commit bb5e032
Show file tree
Hide file tree
Showing 15 changed files with 414 additions and 193 deletions.
19 changes: 18 additions & 1 deletion common/crypto.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
package common

import "golang.org/x/crypto/bcrypt"
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"golang.org/x/crypto/bcrypt"
)

func GenerateHMACWithKey(key []byte, data string) string {
h := hmac.New(sha256.New, key)
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}

func GenerateHMAC(data string) string {
h := hmac.New(sha256.New, []byte(SessionSecret))
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}

func Password2Hash(password string) (string, error) {
passwordBytes := []byte(password)
Expand Down
210 changes: 185 additions & 25 deletions common/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package common

import (
"context"
"errors"
"fmt"
"os"
"reflect"
"strconv"
"time"

"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)

var RDB *redis.Client
Expand Down Expand Up @@ -58,39 +62,167 @@ func RedisGet(key string) (string, error) {
return RDB.Get(ctx, key).Result()
}

func RedisExpire(key string, expiration time.Duration) error {
//func RedisExpire(key string, expiration time.Duration) error {
// ctx := context.Background()
// return RDB.Expire(ctx, key, expiration).Err()
//}
//
//func RedisGetEx(key string, expiration time.Duration) (string, error) {
// ctx := context.Background()
// return RDB.GetSet(ctx, key, expiration).Result()
//}

func RedisDel(key string) error {
ctx := context.Background()
return RDB.Expire(ctx, key, expiration).Err()
return RDB.Del(ctx, key).Err()
}

func RedisGetEx(key string, expiration time.Duration) (string, error) {
func RedisHDelObj(key string) error {
ctx := context.Background()
return RDB.GetSet(ctx, key, expiration).Result()
return RDB.HDel(ctx, key).Err()
}

func RedisDel(key string) error {
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
ctx := context.Background()
return RDB.Del(ctx, key).Err()

data := make(map[string]interface{})

// 使用反射遍历结构体字段
v := reflect.ValueOf(obj).Elem()
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := t.Field(i)
value := v.Field(i)

// Skip DeletedAt field
if field.Type.String() == "gorm.DeletedAt" {
continue
}

// 处理指针类型
if value.Kind() == reflect.Ptr {
if value.IsNil() {
data[field.Name] = ""
continue
}
value = value.Elem()
}

// 处理布尔类型
if value.Kind() == reflect.Bool {
data[field.Name] = strconv.FormatBool(value.Bool())
continue
}

// 其他类型直接转换为字符串
data[field.Name] = fmt.Sprintf("%v", value.Interface())
}

txn := RDB.TxPipeline()
txn.HSet(ctx, key, data)
txn.Expire(ctx, key, expiration)

_, err := txn.Exec(ctx)
if err != nil {
return fmt.Errorf("failed to execute transaction: %w", err)
}
return nil
}

func RedisDecrease(key string, value int64) error {
func RedisHGetObj(key string, obj interface{}) error {
ctx := context.Background()

result, err := RDB.HGetAll(ctx, key).Result()
if err != nil {
return fmt.Errorf("failed to load hash from Redis: %w", err)
}

if len(result) == 0 {
return fmt.Errorf("key %s not found in Redis", key)
}

// Handle both pointer and non-pointer values
val := reflect.ValueOf(obj)
if val.Kind() != reflect.Ptr {
return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
}

v := val.Elem()
if v.Kind() != reflect.Struct {
return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
}

t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := t.Field(i)
fieldName := field.Name
if value, ok := result[fieldName]; ok {
fieldValue := v.Field(i)

// Handle pointer types
if fieldValue.Kind() == reflect.Ptr {
if value == "" {
continue
}
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue = fieldValue.Elem()
}

// Enhanced type handling for Token struct
switch fieldValue.Kind() {
case reflect.String:
fieldValue.SetString(value)
case reflect.Int, reflect.Int64:
intValue, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
}
fieldValue.SetInt(intValue)
case reflect.Bool:
boolValue, err := strconv.ParseBool(value)
if err != nil {
return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
}
fieldValue.SetBool(boolValue)
case reflect.Struct:
// Special handling for gorm.DeletedAt
if fieldValue.Type().String() == "gorm.DeletedAt" {
if value != "" {
timeValue, err := time.Parse(time.RFC3339, value)
if err != nil {
return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
}
fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
}
}
default:
return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
}
}
}

return nil
}

// RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int64) error {
// 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil {
// 失败则尝试直接减少
return RDB.DecrBy(context.Background(), key, value).Err()
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("failed to get TTL: %w", err)
}

// 如果剩余生存时间大于0,则进行减少操作
// 只有在 key 存在且有 TTL 时才需要特殊处理
if ttl > 0 {
ctx := context.Background()
// 开始一个Redis事务
txn := RDB.TxPipeline()

// 减少余额
decrCmd := txn.DecrBy(ctx, key, value)
decrCmd := txn.IncrBy(ctx, key, delta)
if err := decrCmd.Err(); err != nil {
return err // 如果减少失败,则直接返回错误
}
Expand All @@ -101,26 +233,54 @@ func RedisDecrease(key string, value int64) error {
// 执行事务
_, err = txn.Exec(ctx)
return err
} else {
_ = RedisDel(key)
}
return nil
}

// RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int) error {
ctx := context.Background()
func RedisHIncrBy(key, field string, delta int64) error {
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("failed to get TTL: %w", err)
}

// 检查键是否存在
exists, err := RDB.Exists(ctx, key).Result()
if err != nil {
if ttl > 0 {
ctx := context.Background()
txn := RDB.TxPipeline()

incrCmd := txn.HIncrBy(ctx, key, field, delta)
if err := incrCmd.Err(); err != nil {
return err
}

txn.Expire(ctx, key, ttl)

_, err = txn.Exec(ctx)
return err
}
if exists == 0 {
return fmt.Errorf("key does not exist") // 键不存在,返回错误
return nil
}

func RedisHSetField(key, field string, value interface{}) error {
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("failed to get TTL: %w", err)
}

// 键存在,执行INCRBY操作
result := RDB.IncrBy(ctx, key, int64(delta))
return result.Err()
if ttl > 0 {
ctx := context.Background()
txn := RDB.TxPipeline()

hsetCmd := txn.HSet(ctx, key, field, value)
if err := hsetCmd.Err(); err != nil {
return err
}

txn.Expire(ctx, key, ttl)

_, err = txn.Exec(ctx)
return err
}
return nil
}
7 changes: 6 additions & 1 deletion constant/cache_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@ var (
UserId2StatusCacheSeconds = common.SyncFrequency
)

// Cache keys
const (
// Cache keys
UserGroupKeyFmt = "user_group:%d"
UserQuotaKeyFmt = "user_quota:%d"
UserEnabledKeyFmt = "user_enabled:%d"
UserUsernameKeyFmt = "user_name:%d"
)

const (
TokenFiledRemainQuota = "RemainQuota"
TokenFieldGroup = "Group"
)
3 changes: 0 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ func main() {
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
model.InitChannelCache()
}
if common.RedisEnabled {
go model.SyncTokenCache(common.SyncFrequency)
}
if common.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
Expand Down
Loading

0 comments on commit bb5e032

Please sign in to comment.