diff --git a/service/aiproxy/common/custom-event.go b/service/aiproxy/common/custom-event.go
index a7a76219fb92..a3d225f2f499 100644
--- a/service/aiproxy/common/custom-event.go
+++ b/service/aiproxy/common/custom-event.go
@@ -5,9 +5,7 @@ package common
 
 import (
-	"io"
 	"net/http"
-	"strings"
 
 	"github.com/labring/sealos/service/aiproxy/common/conv"
 )
 
@@ -21,43 +19,30 @@ var (
 	noCache     = []string{"no-cache"}
 )
 
-var dataReplacer = strings.NewReplacer(
-	"\n", "\ndata:",
-	"\r", "\\r")
-
-type CustomEvent struct {
-	Data  string
-	Event string
-	ID    string
-	Retry uint
+type OpenAISSE struct {
+	Data string
 }
 
-func encode(writer io.Writer, event CustomEvent) error {
-	return writeData(writer, event.Data)
-}
+const (
+	nn   = "\n\n"
+	data = "data: "
+)
 
-const nn = "\n\n"
+var (
+	nnBytes   = conv.StringToBytes(nn)
+	dataBytes = conv.StringToBytes(data)
+)
 
-var nnBytes = conv.StringToBytes(nn)
+func (r OpenAISSE) Render(w http.ResponseWriter) error {
+	r.WriteContentType(w)
 
-func writeData(w io.Writer, data string) error {
-	_, err := dataReplacer.WriteString(w, data)
-	if err != nil {
-		return err
-	}
-	if strings.HasPrefix(data, "data") {
-		_, err := w.Write(nnBytes)
-		return err
-	}
+	w.Write(dataBytes)
+	w.Write(conv.StringToBytes(r.Data))
+	w.Write(nnBytes)
 	return nil
 }
 
-func (r CustomEvent) Render(w http.ResponseWriter) error {
-	r.WriteContentType(w)
-	return encode(w, r)
-}
-
-func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
+func (r OpenAISSE) WriteContentType(w http.ResponseWriter) {
 	header := w.Header()
 	header["Content-Type"] = contentType
diff --git a/service/aiproxy/common/render/render.go b/service/aiproxy/common/render/render.go
index ff4e8f07b773..8daac428ffcc 100644
--- a/service/aiproxy/common/render/render.go
+++ b/service/aiproxy/common/render/render.go
@@ -3,7 +3,6 @@ package render
 import (
 	"errors"
 	"fmt"
-	"strings"
 
 	"github.com/gin-gonic/gin"
 	json "github.com/json-iterator/go"
@@ -20,9 +19,7 @@ func StringData(c *gin.Context, str string) {
 	if c.IsAborted() {
 		return
 	}
-	str = strings.TrimPrefix(str, "data:")
-	// str = strings.TrimSuffix(str, "\r")
-	c.Render(-1, common.CustomEvent{Data: "data: " + strings.TrimSpace(str)})
+	c.Render(-1, common.OpenAISSE{Data: str})
 	c.Writer.Flush()
 }
 
@@ -37,7 +34,8 @@ func ObjectData(c *gin.Context, object any) error {
 	if err != nil {
 		return fmt.Errorf("error marshalling object: %w", err)
 	}
-	StringData(c, conv.BytesToString(jsonData))
+	c.Render(-1, common.OpenAISSE{Data: conv.BytesToString(jsonData)})
+	c.Writer.Flush()
 	return nil
 }
 
diff --git a/service/aiproxy/relay/adaptor/aws/claude/main.go b/service/aiproxy/relay/adaptor/aws/claude/main.go
index 7a9e19f1d611..55b65b60db22 100644
--- a/service/aiproxy/relay/adaptor/aws/claude/main.go
+++ b/service/aiproxy/relay/adaptor/aws/claude/main.go
@@ -202,7 +202,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
 	c.Stream(func(_ io.Writer) bool {
 		event, ok := <-stream.Events()
 		if !ok {
-			render.StringData(c, "[DONE]")
+			render.Done(c)
 			return false
 		}
 
diff --git a/service/aiproxy/relay/adaptor/aws/llama3/main.go b/service/aiproxy/relay/adaptor/aws/llama3/main.go
index d353096178fe..4ba3147cf368 100644
--- a/service/aiproxy/relay/adaptor/aws/llama3/main.go
+++ b/service/aiproxy/relay/adaptor/aws/llama3/main.go
@@ -14,7 +14,6 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
 	"github.com/gin-gonic/gin"
-	"github.com/labring/sealos/service/aiproxy/common"
 	"github.com/labring/sealos/service/aiproxy/common/random"
 	"github.com/labring/sealos/service/aiproxy/common/render"
 	"github.com/labring/sealos/service/aiproxy/middleware"
@@ -209,7 +208,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
 	c.Stream(func(_ io.Writer) bool {
 		event, ok := <-stream.Events()
 		if !ok {
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			render.Done(c)
 			return false
 		}
 
diff --git a/service/aiproxy/relay/adaptor/openai/main.go b/service/aiproxy/relay/adaptor/openai/main.go
index ebb5053c4ed5..93cbc2c2fc27 100644
--- a/service/aiproxy/relay/adaptor/openai/main.go
+++ b/service/aiproxy/relay/adaptor/openai/main.go
@@ -2,9 +2,10 @@ package openai
 
 import (
 	"bufio"
+	"bytes"
 	"io"
 	"net/http"
-	"strings"
+	"slices"
 
 	"github.com/gin-gonic/gin"
 	json "github.com/json-iterator/go"
@@ -24,6 +25,11 @@ const (
 	DataPrefixLength = len(DataPrefix)
 )
 
+var (
+	DataPrefixBytes = conv.StringToBytes(DataPrefix)
+	DoneBytes       = conv.StringToBytes(Done)
+)
+
 var stdjson = json.ConfigCompatibleWithStandardLibrary
 
 type UsageAndChoicesResponse struct {
@@ -51,22 +57,22 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
 	}
 
 	for scanner.Scan() {
-		data := scanner.Text()
+		data := scanner.Bytes()
 		if len(data) < DataPrefixLength { // ignore blank line or wrong format
 			continue
 		}
-		if data[:DataPrefixLength] != DataPrefix {
+		if !slices.Equal(data[:DataPrefixLength], DataPrefixBytes) {
 			continue
 		}
-		data = strings.TrimSpace(data[DataPrefixLength:])
-
-		if strings.HasPrefix(data, Done) {
+		data = bytes.TrimSpace(data[DataPrefixLength:])
+		if slices.Equal(data, DoneBytes) {
 			break
 		}
+
 		switch meta.Mode {
 		case relaymode.ChatCompletions:
 			var streamResponse UsageAndChoicesResponse
-			err := json.Unmarshal(conv.StringToBytes(data), &streamResponse)
+			err := json.Unmarshal(data, &streamResponse)
 			if err != nil {
 				log.Error("error unmarshalling stream response: " + err.Error())
 				continue
@@ -81,7 +87,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
 				}
 			}
 			respMap := make(map[string]any)
-			err = json.Unmarshal(conv.StringToBytes(data), &respMap)
+			err = json.Unmarshal(data, &respMap)
 			if err != nil {
 				log.Error("error unmarshalling stream response: " + err.Error())
 				continue
@@ -98,7 +104,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
 			_ = render.ObjectData(c, respMap)
 		case relaymode.Completions:
 			var streamResponse CompletionsStreamResponse
-			err := json.Unmarshal(conv.StringToBytes(data), &streamResponse)
+			err := json.Unmarshal(data, &streamResponse)
 			if err != nil {
 				log.Error("error unmarshalling stream response: " + err.Error())
 				continue
@@ -109,7 +115,16 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
 			if streamResponse.Usage != nil {
 				usage = streamResponse.Usage
 			}
-			render.StringData(c, data)
+			respMap := make(map[string]any)
+			err = json.Unmarshal(data, &respMap)
+			if err != nil {
+				log.Error("error unmarshalling stream response: " + err.Error())
+				continue
+			}
+			if _, ok := respMap["model"]; ok && meta.OriginModel != "" {
+				respMap["model"] = meta.OriginModel
+			}
+			_ = render.ObjectData(c, respMap)
 		}
 	}
 
@@ -134,40 +149,41 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
 // renderCallback maybe reuse data, so don't modify data
 func StreamSplitThink(data map[string]any, thinkSplitter *splitter.Splitter, renderCallback func(data map[string]any)) {
 	choices, ok := data["choices"].([]any)
+	// only support one choice
+	if !ok || len(choices) != 1 {
+		renderCallback(data)
+		return
+	}
+	choice := choices[0]
+	choiceMap, ok := choice.(map[string]any)
 	if !ok {
+		renderCallback(data)
 		return
 	}
-	for _, choice := range choices {
-		choiceMap, ok := choice.(map[string]any)
-		if !ok {
-			renderCallback(data)
-			continue
-		}
-		delta, ok := choiceMap["delta"].(map[string]any)
-		if !ok {
-			renderCallback(data)
-			continue
-		}
-		content, ok := delta["content"].(string)
-		if !ok {
-			renderCallback(data)
-			continue
-		}
-		think, remaining := thinkSplitter.Process(conv.StringToBytes(content))
-		if len(think) == 0 && len(remaining) == 0 {
-			renderCallback(data)
-			continue
-		}
-		if len(think) > 0 {
-			delta["content"] = ""
-			delta["reasoning_content"] = conv.BytesToString(think)
-			renderCallback(data)
-		}
-		if len(remaining) > 0 {
-			delta["content"] = conv.BytesToString(remaining)
-			delta["reasoning_content"] = ""
-			renderCallback(data)
-		}
+	delta, ok := choiceMap["delta"].(map[string]any)
+	if !ok {
+		renderCallback(data)
+		return
+	}
+	content, ok := delta["content"].(string)
+	if !ok {
+		renderCallback(data)
+		return
+	}
+	think, remaining := thinkSplitter.Process(conv.StringToBytes(content))
+	if len(think) == 0 && len(remaining) == 0 {
+		renderCallback(data)
+		return
+	}
+	if len(think) > 0 {
+		delta["content"] = ""
+		delta["reasoning_content"] = conv.BytesToString(think)
+		renderCallback(data)
+	}
+	if len(remaining) > 0 {
+		delta["content"] = conv.BytesToString(remaining)
+		delete(delta, "reasoning_content")
+		renderCallback(data)
 	}
 }
 
@@ -216,6 +232,10 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage
 	if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range textResponse.Choices {
+			if choice.Text != "" {
+				completionTokens += CountTokenText(choice.Text, meta.ActualModel)
+				continue
+			}
 			completionTokens += CountTokenText(choice.Message.StringContent(), meta.ActualModel)
 		}
 		textResponse.Usage = model.Usage{
diff --git a/service/aiproxy/relay/adaptor/openai/model.go b/service/aiproxy/relay/adaptor/openai/model.go
index ad0d579bf56e..6ba7de31f6d9 100644
--- a/service/aiproxy/relay/adaptor/openai/model.go
+++ b/service/aiproxy/relay/adaptor/openai/model.go
@@ -83,6 +83,7 @@ type TextResponseChoice struct {
 	FinishReason string        `json:"finish_reason"`
 	Message      model.Message `json:"message"`
 	Index        int           `json:"index"`
+	Text         string        `json:"text"`
 }
 
 type TextResponse struct {
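Note: the AWS Claude and Llama3 stream handlers above now end the stream with render.Done(c) instead of rendering the "[DONE]" string by hand, but that helper is not part of this excerpt. Below is a minimal sketch of what it presumably looks like, assuming it sits next to StringData and ObjectData in service/aiproxy/common/render/render.go and reuses the new common.OpenAISSE renderer; the name, placement, and body are assumptions, not confirmed by this diff.

// Hypothetical sketch only -- not taken from this diff. Assumes a Done helper
// that emits the terminal SSE frame ("data: [DONE]\n\n") through the new
// common.OpenAISSE renderer, matching the behavior previously produced by
// CustomEvent{Data: "data: [DONE]"}.
package render

import (
	"github.com/gin-gonic/gin"

	"github.com/labring/sealos/service/aiproxy/common"
)

// Done writes the final "[DONE]" event of an OpenAI-style SSE stream and
// flushes it to the client. Exact signature and body are assumed.
func Done(c *gin.Context) {
	if c.IsAborted() {
		return
	}
	c.Render(-1, common.OpenAISSE{Data: "[DONE]"})
	c.Writer.Flush()
}

Because OpenAISSE.Render now writes the "data: " prefix and the trailing blank line itself, such a helper only needs to render the bare "[DONE]" payload, which is consistent with the call sites in this diff no longer building the "data: [DONE]" string by hand.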