From 40ec4668c33ff35fc757ca3834e9c618d6652859 Mon Sep 17 00:00:00 2001 From: zijiren <84728412+zijiren233@users.noreply.github.com> Date: Tue, 25 Feb 2025 15:19:46 +0800 Subject: [PATCH] feat: reduce the number of stream memory copies (#5402) * fix: generalize the openai data prefix * feat: reduce the number of stream memory copies * fix: lint * feat: reduce the number of stream memory copies * fix: lint * fix: reset buf when usage * chore: mv event to render * fix: ignore samgrep error --- service/aiproxy/common/custom-event.go | 67 -------- service/aiproxy/common/render/event.go | 51 ++++++ service/aiproxy/common/render/render.go | 9 +- .../aiproxy/relay/adaptor/aws/claude/main.go | 2 +- .../aiproxy/relay/adaptor/aws/llama3/main.go | 3 +- service/aiproxy/relay/adaptor/coze/adaptor.go | 4 +- service/aiproxy/relay/adaptor/coze/main.go | 10 +- service/aiproxy/relay/adaptor/openai/main.go | 150 ++++++++++++------ service/aiproxy/relay/adaptor/openai/model.go | 1 + service/aiproxy/relay/adaptor/openai/token.go | 21 ++- 10 files changed, 179 insertions(+), 139 deletions(-) delete mode 100644 service/aiproxy/common/custom-event.go create mode 100644 service/aiproxy/common/render/event.go diff --git a/service/aiproxy/common/custom-event.go b/service/aiproxy/common/custom-event.go deleted file mode 100644 index a7a76219fb9..00000000000 --- a/service/aiproxy/common/custom-event.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2014 Manu Martinez-Almeida. All rights reserved. -// Use of this source code is governed by a MIT style -// license that can be found in the LICENSE file. - -package common - -import ( - "io" - "net/http" - "strings" - - "github.com/labring/sealos/service/aiproxy/common/conv" -) - -// Server-Sent Events -// W3C Working Draft 29 October 2009 -// http://www.w3.org/TR/2009/WD-eventsource-20091029/ - -var ( - contentType = []string{"text/event-stream"} - noCache = []string{"no-cache"} -) - -var dataReplacer = strings.NewReplacer( - "\n", "\ndata:", - "\r", "\\r") - -type CustomEvent struct { - Data string - Event string - ID string - Retry uint -} - -func encode(writer io.Writer, event CustomEvent) error { - return writeData(writer, event.Data) -} - -const nn = "\n\n" - -var nnBytes = conv.StringToBytes(nn) - -func writeData(w io.Writer, data string) error { - _, err := dataReplacer.WriteString(w, data) - if err != nil { - return err - } - if strings.HasPrefix(data, "data") { - _, err := w.Write(nnBytes) - return err - } - return nil -} - -func (r CustomEvent) Render(w http.ResponseWriter) error { - r.WriteContentType(w) - return encode(w, r) -} - -func (r CustomEvent) WriteContentType(w http.ResponseWriter) { - header := w.Header() - header["Content-Type"] = contentType - - if _, exist := header["Cache-Control"]; !exist { - header["Cache-Control"] = noCache - } -} diff --git a/service/aiproxy/common/render/event.go b/service/aiproxy/common/render/event.go new file mode 100644 index 00000000000..3a4a05c848b --- /dev/null +++ b/service/aiproxy/common/render/event.go @@ -0,0 +1,51 @@ +package render + +import ( + "net/http" + + "github.com/labring/sealos/service/aiproxy/common/conv" +) + +var ( + contentType = []string{"text/event-stream"} + noCache = []string{"no-cache"} +) + +type OpenAISSE struct { + Data string +} + +const ( + nn = "\n\n" + data = "data: " +) + +var ( + nnBytes = conv.StringToBytes(nn) + dataBytes = conv.StringToBytes(data) +) + +func (r OpenAISSE) Render(w http.ResponseWriter) error { + r.WriteContentType(w) + + for _, bytes := range [][]byte{ + dataBytes, + conv.StringToBytes(r.Data), + nnBytes, + } { + // nosemgrep: go.lang.security.audit.xss.no-direct-write-to-responsewriter.no-direct-write-to-responsewriter + if _, err := w.Write(bytes); err != nil { + return err + } + } + return nil +} + +func (r OpenAISSE) WriteContentType(w http.ResponseWriter) { + header := w.Header() + header["Content-Type"] = contentType + + if _, exist := header["Cache-Control"]; !exist { + header["Cache-Control"] = noCache + } +} diff --git a/service/aiproxy/common/render/render.go b/service/aiproxy/common/render/render.go index ff4e8f07b77..d2328de155b 100644 --- a/service/aiproxy/common/render/render.go +++ b/service/aiproxy/common/render/render.go @@ -3,11 +3,9 @@ package render import ( "errors" "fmt" - "strings" "github.com/gin-gonic/gin" json "github.com/json-iterator/go" - "github.com/labring/sealos/service/aiproxy/common" "github.com/labring/sealos/service/aiproxy/common/conv" ) @@ -20,9 +18,7 @@ func StringData(c *gin.Context, str string) { if c.IsAborted() { return } - str = strings.TrimPrefix(str, "data:") - // str = strings.TrimSuffix(str, "\r") - c.Render(-1, common.CustomEvent{Data: "data: " + strings.TrimSpace(str)}) + c.Render(-1, OpenAISSE{Data: str}) c.Writer.Flush() } @@ -37,7 +33,8 @@ func ObjectData(c *gin.Context, object any) error { if err != nil { return fmt.Errorf("error marshalling object: %w", err) } - StringData(c, conv.BytesToString(jsonData)) + c.Render(-1, OpenAISSE{Data: conv.BytesToString(jsonData)}) + c.Writer.Flush() return nil } diff --git a/service/aiproxy/relay/adaptor/aws/claude/main.go b/service/aiproxy/relay/adaptor/aws/claude/main.go index 7a9e19f1d61..55b65b60db2 100644 --- a/service/aiproxy/relay/adaptor/aws/claude/main.go +++ b/service/aiproxy/relay/adaptor/aws/claude/main.go @@ -202,7 +202,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus c.Stream(func(_ io.Writer) bool { event, ok := <-stream.Events() if !ok { - render.StringData(c, "[DONE]") + render.Done(c) return false } diff --git a/service/aiproxy/relay/adaptor/aws/llama3/main.go b/service/aiproxy/relay/adaptor/aws/llama3/main.go index d353096178f..4ba3147cf36 100644 --- a/service/aiproxy/relay/adaptor/aws/llama3/main.go +++ b/service/aiproxy/relay/adaptor/aws/llama3/main.go @@ -14,7 +14,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" "github.com/gin-gonic/gin" - "github.com/labring/sealos/service/aiproxy/common" "github.com/labring/sealos/service/aiproxy/common/random" "github.com/labring/sealos/service/aiproxy/common/render" "github.com/labring/sealos/service/aiproxy/middleware" @@ -209,7 +208,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus c.Stream(func(_ io.Writer) bool { event, ok := <-stream.Events() if !ok { - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + render.Done(c) return false } diff --git a/service/aiproxy/relay/adaptor/coze/adaptor.go b/service/aiproxy/relay/adaptor/coze/adaptor.go index 50687f4f880..1c8306c9e01 100644 --- a/service/aiproxy/relay/adaptor/coze/adaptor.go +++ b/service/aiproxy/relay/adaptor/coze/adaptor.go @@ -82,9 +82,9 @@ func (a *Adaptor) DoRequest(_ *meta.Meta, _ *gin.Context, req *http.Request) (*h func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (usage *relaymodel.Usage, err *relaymodel.ErrorWithStatusCode) { var responseText *string if utils.IsStreamResponse(resp) { - err, responseText = StreamHandler(c, resp) + err, responseText = StreamHandler(meta, c, resp) } else { - err, responseText = Handler(c, resp, meta.InputTokens, meta.ActualModel) + err, responseText = Handler(meta, c, resp) } if responseText != nil { usage = openai.ResponseText2Usage(*responseText, meta.ActualModel, meta.InputTokens) diff --git a/service/aiproxy/relay/adaptor/coze/main.go b/service/aiproxy/relay/adaptor/coze/main.go index d2b644f3c40..e66c99bb2cf 100644 --- a/service/aiproxy/relay/adaptor/coze/main.go +++ b/service/aiproxy/relay/adaptor/coze/main.go @@ -14,6 +14,7 @@ import ( "github.com/labring/sealos/service/aiproxy/relay/adaptor/coze/constant/messagetype" "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai" "github.com/labring/sealos/service/aiproxy/relay/constant" + "github.com/labring/sealos/service/aiproxy/relay/meta" "github.com/labring/sealos/service/aiproxy/relay/model" ) @@ -85,7 +86,7 @@ func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse { return &fullTextResponse } -func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) { +func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) { defer resp.Body.Close() log := middleware.GetLogger(c) @@ -96,7 +97,6 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC scanner.Split(bufio.ScanLines) common.SetEventStreamHeaders(c) - var modelName string for scanner.Scan() { data := scanner.Bytes() @@ -124,7 +124,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC for _, choice := range response.Choices { responseText += conv.AsString(choice.Delta.Content) } - response.Model = modelName + response.Model = meta.OriginModel response.Created = createdTime _ = render.ObjectData(c, response) @@ -139,7 +139,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC return nil, &responseText } -func Handler(c *gin.Context, resp *http.Response, _ int, modelName string) (*model.ErrorWithStatusCode, *string) { +func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) { defer resp.Body.Close() log := middleware.GetLogger(c) @@ -159,7 +159,7 @@ func Handler(c *gin.Context, resp *http.Response, _ int, modelName string) (*mod }, nil } fullTextResponse := ResponseCoze2OpenAI(&cozeResponse) - fullTextResponse.Model = modelName + fullTextResponse.Model = meta.OriginModel jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/service/aiproxy/relay/adaptor/openai/main.go b/service/aiproxy/relay/adaptor/openai/main.go index ab9229d5597..eb5f7341ba7 100644 --- a/service/aiproxy/relay/adaptor/openai/main.go +++ b/service/aiproxy/relay/adaptor/openai/main.go @@ -2,9 +2,12 @@ package openai import ( "bufio" + "bytes" "io" "net/http" + "slices" "strings" + "sync" "github.com/gin-gonic/gin" json "github.com/json-iterator/go" @@ -19,11 +22,16 @@ import ( ) const ( - DataPrefix = "data: " + DataPrefix = "data:" Done = "[DONE]" DataPrefixLength = len(DataPrefix) ) +var ( + DataPrefixBytes = conv.StringToBytes(DataPrefix) + DoneBytes = conv.StringToBytes(Done) +) + var stdjson = json.ConfigCompatibleWithStandardLibrary type UsageAndChoicesResponse struct { @@ -31,14 +39,38 @@ type UsageAndChoicesResponse struct { Choices []*ChatCompletionsStreamResponseChoice } +const scannerBufferSize = 2 * bufio.MaxScanTokenSize + +var scannerBufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, scannerBufferSize) + return &buf + }, +} + +//nolint:forcetypeassert +func getScannerBuffer() *[]byte { + return scannerBufferPool.Get().(*[]byte) +} + +func putScannerBuffer(buf *[]byte) { + if cap(*buf) != scannerBufferSize { + return + } + scannerBufferPool.Put(buf) +} + func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *model.ErrorWithStatusCode) { defer resp.Body.Close() log := middleware.GetLogger(c) - responseText := "" + responseText := strings.Builder{} + scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) + buf := getScannerBuffer() + defer putScannerBuffer(buf) + scanner.Buffer(*buf, cap(*buf)) var usage *model.Usage @@ -51,36 +83,40 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model } for scanner.Scan() { - data := scanner.Text() + data := scanner.Bytes() if len(data) < DataPrefixLength { // ignore blank line or wrong format continue } - if data[:DataPrefixLength] != DataPrefix { + if !slices.Equal(data[:DataPrefixLength], DataPrefixBytes) { continue } - data = data[DataPrefixLength:] - if strings.HasPrefix(data, Done) { + data = bytes.TrimSpace(data[DataPrefixLength:]) + if slices.Equal(data, DoneBytes) { break } + switch meta.Mode { case relaymode.ChatCompletions: var streamResponse UsageAndChoicesResponse - err := json.Unmarshal(conv.StringToBytes(data), &streamResponse) + err := json.Unmarshal(data, &streamResponse) if err != nil { log.Error("error unmarshalling stream response: " + err.Error()) continue } if streamResponse.Usage != nil { usage = streamResponse.Usage + responseText.Reset() } for _, choice := range streamResponse.Choices { - responseText += choice.Delta.StringContent() + if usage == nil { + responseText.WriteString(choice.Delta.StringContent()) + } if choice.Delta.ReasoningContent != "" { hasReasoningContent = true } } respMap := make(map[string]any) - err = json.Unmarshal(conv.StringToBytes(data), &respMap) + err = json.Unmarshal(data, &respMap) if err != nil { log.Error("error unmarshalling stream response: " + err.Error()) continue @@ -97,18 +133,29 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model _ = render.ObjectData(c, respMap) case relaymode.Completions: var streamResponse CompletionsStreamResponse - err := json.Unmarshal(conv.StringToBytes(data), &streamResponse) + err := json.Unmarshal(data, &streamResponse) if err != nil { log.Error("error unmarshalling stream response: " + err.Error()) continue } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } if streamResponse.Usage != nil { usage = streamResponse.Usage + responseText.Reset() + } else { + for _, choice := range streamResponse.Choices { + responseText.WriteString(choice.Text) + } + } + respMap := make(map[string]any) + err = json.Unmarshal(data, &respMap) + if err != nil { + log.Error("error unmarshalling stream response: " + err.Error()) + continue } - render.StringData(c, data) + if _, ok := respMap["model"]; ok && meta.OriginModel != "" { + respMap["model"] = meta.OriginModel + } + _ = render.ObjectData(c, respMap) } } @@ -118,8 +165,8 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model render.Done(c) - if usage == nil || (usage.TotalTokens == 0 && responseText != "") { - usage = ResponseText2Usage(responseText, meta.ActualModel, meta.InputTokens) + if usage == nil || (usage.TotalTokens == 0 && responseText.Len() > 0) { + usage = ResponseText2Usage(responseText.String(), meta.ActualModel, meta.InputTokens) } if usage.TotalTokens != 0 && usage.PromptTokens == 0 { // some channels don't return prompt tokens & completion tokens @@ -133,40 +180,41 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model // renderCallback maybe reuse data, so don't modify data func StreamSplitThink(data map[string]any, thinkSplitter *splitter.Splitter, renderCallback func(data map[string]any)) { choices, ok := data["choices"].([]any) + // only support one choice + if !ok || len(choices) != 1 { + renderCallback(data) + return + } + choice := choices[0] + choiceMap, ok := choice.(map[string]any) if !ok { + renderCallback(data) return } - for _, choice := range choices { - choiceMap, ok := choice.(map[string]any) - if !ok { - renderCallback(data) - continue - } - delta, ok := choiceMap["delta"].(map[string]any) - if !ok { - renderCallback(data) - continue - } - content, ok := delta["content"].(string) - if !ok { - renderCallback(data) - continue - } - think, remaining := thinkSplitter.Process(conv.StringToBytes(content)) - if len(think) == 0 && len(remaining) == 0 { - renderCallback(data) - continue - } - if len(think) > 0 { - delta["content"] = "" - delta["reasoning_content"] = conv.BytesToString(think) - renderCallback(data) - } - if len(remaining) > 0 { - delta["content"] = conv.BytesToString(remaining) - delta["reasoning_content"] = "" - renderCallback(data) - } + delta, ok := choiceMap["delta"].(map[string]any) + if !ok { + renderCallback(data) + return + } + content, ok := delta["content"].(string) + if !ok { + renderCallback(data) + return + } + think, remaining := thinkSplitter.Process(conv.StringToBytes(content)) + if len(think) == 0 && len(remaining) == 0 { + renderCallback(data) + return + } + if len(think) > 0 { + delta["content"] = "" + delta["reasoning_content"] = conv.BytesToString(think) + renderCallback(data) + } + if len(remaining) > 0 { + delta["content"] = conv.BytesToString(remaining) + delete(delta, "reasoning_content") + renderCallback(data) } } @@ -203,11 +251,13 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage if err != nil { return nil, ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } + var textResponse SlimTextResponse err = json.Unmarshal(responseBody, &textResponse) if err != nil { return nil, ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } + if textResponse.Error.Type != "" { return nil, ErrorWrapperWithMessage(textResponse.Error.Message, textResponse.Error.Code, http.StatusBadRequest) } @@ -215,6 +265,10 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range textResponse.Choices { + if choice.Text != "" { + completionTokens += CountTokenText(choice.Text, meta.ActualModel) + continue + } completionTokens += CountTokenText(choice.Message.StringContent(), meta.ActualModel) } textResponse.Usage = model.Usage{ diff --git a/service/aiproxy/relay/adaptor/openai/model.go b/service/aiproxy/relay/adaptor/openai/model.go index ad0d579bf56..6ba7de31f6d 100644 --- a/service/aiproxy/relay/adaptor/openai/model.go +++ b/service/aiproxy/relay/adaptor/openai/model.go @@ -83,6 +83,7 @@ type TextResponseChoice struct { FinishReason string `json:"finish_reason"` Message model.Message `json:"message"` Index int `json:"index"` + Text string `json:"text"` } type TextResponse struct { diff --git a/service/aiproxy/relay/adaptor/openai/token.go b/service/aiproxy/relay/adaptor/openai/token.go index 2b3df7d2be7..19c0e7fb190 100644 --- a/service/aiproxy/relay/adaptor/openai/token.go +++ b/service/aiproxy/relay/adaptor/openai/token.go @@ -79,7 +79,10 @@ func CountTokenMessages(messages []*model.Message, model string) int { tokenNum += getTokenNum(tokenEncoder, v) case []any: for _, it := range v { - m := it.(map[string]any) + m, ok := it.(map[string]any) + if !ok { + continue + } switch m["type"] { case "text": if textValue, ok := m["text"]; ok { @@ -90,10 +93,16 @@ func CountTokenMessages(messages []*model.Message, model string) int { case "image_url": imageURL, ok := m["image_url"].(map[string]any) if ok { - url := imageURL["url"].(string) + url, ok := imageURL["url"].(string) + if !ok { + continue + } detail := "" if imageURL["detail"] != nil { - detail = imageURL["detail"].(string) + detail, ok = imageURL["detail"].(string) + if !ok { + continue + } } imageTokens, err := countImageTokens(url, detail, model) if err != nil { @@ -217,11 +226,7 @@ func CountTokenText(text string, model string) int { if strings.HasPrefix(model, "tts") { return utf8.RuneCountInString(text) } - if strings.HasPrefix(model, "sambert-") { - return len(text) - } - tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text) + return getTokenNum(getTokenEncoder(model), text) } func CountToken(text string) int {