Skip to content

Commit

Permalink
feat: baidu keyhelp and doubao mega (#5409)
Browse files Browse the repository at this point in the history
* feat: try support doubao tts icl

* fix: doubao mega

* fix: add baidu key help
  • Loading branch information
zijiren233 authored Feb 26, 2025
1 parent 813e7a0 commit 692ab3a
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 11 deletions.
28 changes: 28 additions & 0 deletions service/aiproxy/relay/adaptor/baidu/key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package baidu

import (
"errors"
"strings"

"github.com/labring/sealos/service/aiproxy/relay/adaptor"
)

var _ adaptor.KeyValidator = (*Adaptor)(nil)

func (a *Adaptor) ValidateKey(key string) error {
_, _, err := getClientIDAndSecret(key)
return err
}

func (a *Adaptor) KeyHelp() string {
return "client_id|client_secret"
}

// key格式: client_id|client_secret
func getClientIDAndSecret(key string) (string, string, error) {
parts := strings.Split(key, "|")
if len(parts) != 2 {
return "", "", errors.New("invalid key format")
}
return parts[0], parts[1], nil
}
9 changes: 4 additions & 5 deletions service/aiproxy/relay/adaptor/baidu/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -52,14 +51,14 @@ func GetAccessToken(ctx context.Context, apiKey string) (string, error) {
}

func getBaiduAccessTokenHelper(ctx context.Context, apiKey string) (*AccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
clientID, clientSecret, err := getClientIDAndSecret(apiKey)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx,
http.MethodPost,
fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
parts[0], parts[1]),
clientID, clientSecret),
nil)
if err != nil {
return nil, err
Expand Down
28 changes: 28 additions & 0 deletions service/aiproxy/relay/adaptor/baiduv2/key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package baiduv2

import (
"errors"
"strings"

"github.com/labring/sealos/service/aiproxy/relay/adaptor"
)

var _ adaptor.KeyValidator = (*Adaptor)(nil)

func (a *Adaptor) ValidateKey(key string) error {
_, _, err := getAKAndSK(key)
return err
}

func (a *Adaptor) KeyHelp() string {
return "ak|sk"
}

// key格式: ak|sk
func getAKAndSK(key string) (string, string, error) {
parts := strings.Split(key, "|")
if len(parts) != 2 {
return "", "", errors.New("invalid key format")
}
return parts[0], parts[1], nil
}
8 changes: 4 additions & 4 deletions service/aiproxy/relay/adaptor/baiduv2/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ func GetBearerToken(ctx context.Context, apiKey string) (*TokenResponse, error)
}

func getBaiduAccessTokenHelper(ctx context.Context, apiKey string) (*TokenResponse, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
ak, sk, err := getAKAndSK(apiKey)
if err != nil {
return nil, err
}
authorization := generateAuthorizationString(parts[0], parts[1])
authorization := generateAuthorizationString(ak, sk)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://iam.bj.baidubce.com/v1/BCE-BEARER/token", nil)
if err != nil {
return nil, err
Expand Down
13 changes: 11 additions & 2 deletions service/aiproxy/relay/adaptor/doubaoaudio/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"io"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/google/uuid"
Expand Down Expand Up @@ -78,11 +79,18 @@ func ConvertTTSRequest(meta *meta.Meta, req *http.Request) (string, http.Header,
return "", nil, nil, err
}

cluster := "volcano_tts"
textType := "ssml"
if strings.HasPrefix(request.Voice, "S_") {
cluster = "volcano_mega"
textType = "plain"
}

doubaoRequest := DoubaoTTSRequest{
App: AppConfig{
AppID: appID,
Token: token,
Cluster: "volcano_tts",
Cluster: cluster,
},
User: UserConfig{
UID: meta.RequestID,
Expand All @@ -93,7 +101,7 @@ func ConvertTTSRequest(meta *meta.Meta, req *http.Request) (string, http.Header,
Request: RequestConfig{
ReqID: uuid.New().String(),
Text: request.Input,
TextType: "ssml", // plain
TextType: textType,
Operation: "submit",
},
}
Expand Down Expand Up @@ -220,6 +228,7 @@ func gzipDecompress(input []byte) ([]byte, error) {
if err != nil {
return nil, err
}
defer r.Close()
out, err := io.ReadAll(r)
if err != nil {
return nil, err
Expand Down

0 comments on commit 692ab3a

Please sign in to comment.