From e84300f4aee6a01270add0fd69e41a4294abf6ad Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 19 Jul 2024 01:07:37 +0800 Subject: [PATCH] chore: gopool --- controller/channel-test.go | 5 +- go.mod | 1 + go.sum | 4 + main.go | 5 +- model/log.go | 3 +- model/utils.go | 5 +- relay/channel/ali/adaptor.go | 2 +- relay/channel/claude/relay-claude.go | 129 ++++++++++----------------- relay/channel/ollama/adaptor.go | 2 +- relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/relay-openai.go | 17 ++-- relay/channel/perplexity/adaptor.go | 2 +- relay/channel/zhipu/relay-zhipu.go | 13 +-- relay/channel/zhipu_4v/adaptor.go | 2 +- 14 files changed, 77 insertions(+), 115 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 90d02d617..fe279785a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/bytedance/gopkg/util/gopool" "io" "math" "net/http" @@ -217,7 +218,7 @@ func testAllChannels(notify bool) error { if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value } - go func() { + gopool.Go(func() { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() @@ -265,7 +266,7 @@ func testAllChannels(notify bool) error { common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } - }() + }) return nil } diff --git a/go.mod b/go.mod index a9d4a1d3b..f97217b40 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect + github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect diff --git a/go.sum b/go.sum index a77a89c70..f19b88cb8 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0= +github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -198,6 +200,7 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -206,6 +209,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/main.go b/main.go index ed2ab2e28..959b795f8 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "embed" "fmt" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" @@ -91,10 +92,10 @@ func main() { go controller.AutomaticallyTestChannels(frequency) } if common.IsMasterNode && constant.UpdateTask { - common.SafeGoroutine(func() { + gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() }) - common.SafeGoroutine(func() { + gopool.Go(func() { controller.UpdateTaskBulk() }) } diff --git a/model/log.go b/model/log.go index cea5b98dd..85c53b1fb 100644 --- a/model/log.go +++ b/model/log.go @@ -3,6 +3,7 @@ package model import ( "context" "fmt" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" "one-api/common" "strings" @@ -87,7 +88,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke common.LogError(ctx, "failed to record log: "+err.Error()) } if common.DataExportEnabled { - common.SafeGoroutine(func() { + gopool.Go(func() { LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens) }) } diff --git a/model/utils.go b/model/utils.go index 44bfbb9e2..3905e9511 100644 --- a/model/utils.go +++ b/model/utils.go @@ -2,6 +2,7 @@ package model import ( "errors" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" "one-api/common" "sync" @@ -28,12 +29,12 @@ func init() { } func InitBatchUpdater() { - go func() { + gopool.Go(func() { for { time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) batchUpdate() } - }() + }) } func addNewRecord(type_ int, id int, value int) { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 98728a0a5..ff9d5330a 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -84,7 +84,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = aliEmbeddingHandler(c, resp) default: if info.IsStream { - err, usage = openai.OpenaiStreamHandler(c, resp, info) + err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 0d707157d..031f82537 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -8,12 +8,10 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" "strings" - "time" ) func stopReasonClaude2OpenAI(reason string) string { @@ -332,91 +330,59 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. responseText := "" createdTime := common.GetTimestamp() scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil + scanner.Split(bufio.ScanLines) + service.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + info.SetFirstResponseTime() + if len(data) < 6 || !strings.HasPrefix(data, "data:") { + continue } - if atEOF { - return len(data), data, nil + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSpace(data) + var claudeResponse ClaudeResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + continue } - return 0, nil, nil - }) - dataChan := make(chan string, 5) - stopChan := make(chan bool, 2) - go func() { - for scanner.Scan() { - data := scanner.Text() - if !strings.HasPrefix(data, "data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { - // send data timeout, stop the stream - common.LogError(c, "send data timeout, stop the stream") - break - } + + response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) + if response == nil { + continue } - stopChan <- true - }() - isFirst := true - service.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var claudeResponse ClaudeResponse - err := json.Unmarshal([]byte(data), &claudeResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } + if requestMode == RequestModeCompletion { + responseText += claudeResponse.Completion + responseId = response.Id + } else { + if claudeResponse.Type == "message_start" { + // message_start, 获取usage + responseId = claudeResponse.Message.Id + info.UpstreamModelName = claudeResponse.Message.Model + usage.PromptTokens = claudeUsage.InputTokens + } else if claudeResponse.Type == "content_block_delta" { + responseText += claudeResponse.Delta.Text + } else if claudeResponse.Type == "message_delta" { + usage.CompletionTokens = claudeUsage.OutputTokens + usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens + } else if claudeResponse.Type == "content_block_start" { - response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) - if response == nil { - return true - } - if requestMode == RequestModeCompletion { - responseText += claudeResponse.Completion - responseId = response.Id } else { - if claudeResponse.Type == "message_start" { - // message_start, 获取usage - responseId = claudeResponse.Message.Id - info.UpstreamModelName = claudeResponse.Message.Model - usage.PromptTokens = claudeUsage.InputTokens - } else if claudeResponse.Type == "content_block_delta" { - responseText += claudeResponse.Delta.Text - } else if claudeResponse.Type == "message_delta" { - usage.CompletionTokens = claudeUsage.OutputTokens - usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens - } else if claudeResponse.Type == "content_block_start" { - - } else { - return true - } + continue } - //response.Id = responseId - response.Id = responseId - response.Created = createdTime - response.Model = info.UpstreamModelName + } + //response.Id = responseId + response.Id = responseId + response.Created = createdTime + response.Model = info.UpstreamModelName - err = service.ObjectData(c, response) - if err != nil { - common.SysError(err.Error()) - } - return true - case <-stopChan: - return false + err = service.ObjectData(c, response) + if err != nil { + common.LogError(c, "send_stream_response_failed: "+err.Error()) } - }) + } + if requestMode == RequestModeCompletion { usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { @@ -435,10 +401,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } service.Done(c) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + resp.Body.Close() return nil, usage } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 540ec8596..408db6aae 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -64,7 +64,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = openai.OpenaiStreamHandler(c, resp, info) + err, usage = openai.OaiStreamHandler(c, resp, info) } else { if info.RelayMode == relayconstant.RelayModeEmbeddings { err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 2aa743f53..4388efd6d 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -145,7 +145,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = OpenaiTTSHandler(c, resp, info) default: if info.IsStream { - err, usage = OpenaiStreamHandler(c, resp, info) + err, usage = OaiStreamHandler(c, resp, info) } else { err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 45e5defb7..807f4b18f 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" "io" "net/http" @@ -18,8 +19,8 @@ import ( "time" ) -func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - hasStreamUsage := false +func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + containStreamUsage := false responseId := "" var createAt int64 = 0 var systemFingerprint string @@ -41,7 +42,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. stopChan := make(chan bool) defer close(stopChan) - go func() { + gopool.Go(func() { for scanner.Scan() { info.SetFirstResponseTime() ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) @@ -62,7 +63,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } common.SafeSendBool(stopChan, true) - }() + }) select { case <-ticker.C: @@ -91,7 +92,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. model = streamResponse.Model if service.ValidUsage(streamResponse.Usage) { usage = streamResponse.Usage - hasStreamUsage = true + containStreamUsage = true } for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -115,7 +116,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. model = streamResponse.Model if service.ValidUsage(streamResponse.Usage) { usage = streamResponse.Usage - hasStreamUsage = true + containStreamUsage = true } for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -155,12 +156,12 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } - if !hasStreamUsage { + if !containStreamUsage { usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage.CompletionTokens += toolCount * 7 } - if info.ShouldIncludeUsage && !hasStreamUsage { + if info.ShouldIncludeUsage && !containStreamUsage { response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) service.ObjectData(c, response) diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index d3ed222cc..e9d07fbea 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -58,7 +58,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = openai.OpenaiStreamHandler(c, resp, info) + err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 5ef9d7ab8..aaf3c5dd4 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -153,18 +153,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var usage *dto.Usage scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { - return i + 2, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) dataChan := make(chan string) metaChan := make(chan string) stopChan := make(chan bool) diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index b34b756ba..5e0906efe 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -59,7 +59,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = openai.OpenaiStreamHandler(c, resp, info) + err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) }