Skip to content

Commit

Permalink
feat: reduce the number of stream memory copies
Browse files Browse the repository at this point in the history
  • Loading branch information
zijiren233 committed Feb 24, 2025
1 parent 6f1b92a commit 646f592
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 80 deletions.
47 changes: 16 additions & 31 deletions service/aiproxy/common/custom-event.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
package common

import (
"io"
"net/http"
"strings"

"github.com/labring/sealos/service/aiproxy/common/conv"
)
Expand All @@ -21,43 +19,30 @@ var (
noCache = []string{"no-cache"}
)

var dataReplacer = strings.NewReplacer(
"\n", "\ndata:",
"\r", "\\r")

type CustomEvent struct {
Data string
Event string
ID string
Retry uint
type OpenAISSE struct {
Data string
}

func encode(writer io.Writer, event CustomEvent) error {
return writeData(writer, event.Data)
}
const (
nn = "\n\n"
data = "data: "
)

const nn = "\n\n"
var (
nnBytes = conv.StringToBytes(nn)
dataBytes = conv.StringToBytes(data)
)

var nnBytes = conv.StringToBytes(nn)
func (r OpenAISSE) Render(w http.ResponseWriter) error {
r.WriteContentType(w)

func writeData(w io.Writer, data string) error {
_, err := dataReplacer.WriteString(w, data)
if err != nil {
return err
}
if strings.HasPrefix(data, "data") {
_, err := w.Write(nnBytes)
return err
}
w.Write(dataBytes)

Check failure on line 39 in service/aiproxy/common/custom-event.go

View workflow job for this annotation

GitHub Actions / golangci-lint (./service/aiproxy)

Error return value of `w.Write` is not checked (errcheck)
w.Write(conv.StringToBytes(r.Data))

Check failure on line 40 in service/aiproxy/common/custom-event.go

View workflow job for this annotation

GitHub Actions / golangci-lint (./service/aiproxy)

Error return value of `w.Write` is not checked (errcheck)
w.Write(nnBytes)

Check failure on line 41 in service/aiproxy/common/custom-event.go

View workflow job for this annotation

GitHub Actions / golangci-lint (./service/aiproxy)

Error return value of `w.Write` is not checked (errcheck)
return nil
}

func (r CustomEvent) Render(w http.ResponseWriter) error {
r.WriteContentType(w)
return encode(w, r)
}

func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
func (r OpenAISSE) WriteContentType(w http.ResponseWriter) {
header := w.Header()
header["Content-Type"] = contentType

Expand Down
8 changes: 3 additions & 5 deletions service/aiproxy/common/render/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package render
import (
"errors"
"fmt"
"strings"

"github.com/gin-gonic/gin"
json "github.com/json-iterator/go"
Expand All @@ -20,9 +19,7 @@ func StringData(c *gin.Context, str string) {
if c.IsAborted() {
return
}
str = strings.TrimPrefix(str, "data:")
// str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + strings.TrimSpace(str)})
c.Render(-1, common.OpenAISSE{Data: str})
c.Writer.Flush()
}

Expand All @@ -37,7 +34,8 @@ func ObjectData(c *gin.Context, object any) error {
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
StringData(c, conv.BytesToString(jsonData))
c.Render(-1, common.OpenAISSE{Data: conv.BytesToString(jsonData)})
c.Writer.Flush()
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/aws/claude/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
c.Stream(func(_ io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
render.StringData(c, "[DONE]")
render.Done(c)
return false
}

Expand Down
3 changes: 1 addition & 2 deletions service/aiproxy/relay/adaptor/aws/llama3/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/labring/sealos/service/aiproxy/common"
"github.com/labring/sealos/service/aiproxy/common/random"
"github.com/labring/sealos/service/aiproxy/common/render"
"github.com/labring/sealos/service/aiproxy/middleware"
Expand Down Expand Up @@ -209,7 +208,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
c.Stream(func(_ io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
render.Done(c)
return false
}

Expand Down
102 changes: 61 additions & 41 deletions service/aiproxy/relay/adaptor/openai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package openai

import (
"bufio"
"bytes"
"io"
"net/http"
"strings"
"slices"

"github.com/gin-gonic/gin"
json "github.com/json-iterator/go"
Expand All @@ -24,6 +25,11 @@ const (
DataPrefixLength = len(DataPrefix)
)

var (
DataPrefixBytes = conv.StringToBytes(DataPrefix)
DoneBytes = conv.StringToBytes(Done)
)

var stdjson = json.ConfigCompatibleWithStandardLibrary

type UsageAndChoicesResponse struct {
Expand Down Expand Up @@ -51,22 +57,22 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
}

for scanner.Scan() {
data := scanner.Text()
data := scanner.Bytes()
if len(data) < DataPrefixLength { // ignore blank line or wrong format
continue
}
if data[:DataPrefixLength] != DataPrefix {
if !slices.Equal(data[:DataPrefixLength], DataPrefixBytes) {
continue
}
data = strings.TrimSpace(data[DataPrefixLength:])

if strings.HasPrefix(data, Done) {
data = bytes.TrimSpace(data[DataPrefixLength:])
if slices.Equal(data, DoneBytes) {
break
}

switch meta.Mode {
case relaymode.ChatCompletions:
var streamResponse UsageAndChoicesResponse
err := json.Unmarshal(conv.StringToBytes(data), &streamResponse)
err := json.Unmarshal(data, &streamResponse)
if err != nil {
log.Error("error unmarshalling stream response: " + err.Error())
continue
Expand All @@ -81,7 +87,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
}
}
respMap := make(map[string]any)
err = json.Unmarshal(conv.StringToBytes(data), &respMap)
err = json.Unmarshal(data, &respMap)
if err != nil {
log.Error("error unmarshalling stream response: " + err.Error())
continue
Expand All @@ -98,7 +104,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
_ = render.ObjectData(c, respMap)
case relaymode.Completions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal(conv.StringToBytes(data), &streamResponse)
err := json.Unmarshal(data, &streamResponse)
if err != nil {
log.Error("error unmarshalling stream response: " + err.Error())
continue
Expand All @@ -109,7 +115,16 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
render.StringData(c, data)
respMap := make(map[string]any)
err = json.Unmarshal(data, &respMap)
if err != nil {
log.Error("error unmarshalling stream response: " + err.Error())
continue
}
if _, ok := respMap["model"]; ok && meta.OriginModel != "" {
respMap["model"] = meta.OriginModel
}
_ = render.ObjectData(c, respMap)
}
}

Expand All @@ -134,40 +149,41 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
// renderCallback maybe reuse data, so don't modify data
func StreamSplitThink(data map[string]any, thinkSplitter *splitter.Splitter, renderCallback func(data map[string]any)) {
choices, ok := data["choices"].([]any)
// only support one choice
if !ok || len(choices) != 1 {
renderCallback(data)
return
}
choice := choices[0]
choiceMap, ok := choice.(map[string]any)
if !ok {
renderCallback(data)
return
}
for _, choice := range choices {
choiceMap, ok := choice.(map[string]any)
if !ok {
renderCallback(data)
continue
}
delta, ok := choiceMap["delta"].(map[string]any)
if !ok {
renderCallback(data)
continue
}
content, ok := delta["content"].(string)
if !ok {
renderCallback(data)
continue
}
think, remaining := thinkSplitter.Process(conv.StringToBytes(content))
if len(think) == 0 && len(remaining) == 0 {
renderCallback(data)
continue
}
if len(think) > 0 {
delta["content"] = ""
delta["reasoning_content"] = conv.BytesToString(think)
renderCallback(data)
}
if len(remaining) > 0 {
delta["content"] = conv.BytesToString(remaining)
delta["reasoning_content"] = ""
renderCallback(data)
}
delta, ok := choiceMap["delta"].(map[string]any)
if !ok {
renderCallback(data)
return
}
content, ok := delta["content"].(string)
if !ok {
renderCallback(data)
return
}
think, remaining := thinkSplitter.Process(conv.StringToBytes(content))
if len(think) == 0 && len(remaining) == 0 {
renderCallback(data)
return
}
if len(think) > 0 {
delta["content"] = ""
delta["reasoning_content"] = conv.BytesToString(think)
renderCallback(data)
}
if len(remaining) > 0 {
delta["content"] = conv.BytesToString(remaining)
delete(delta, "reasoning_content")
renderCallback(data)
}
}

Expand Down Expand Up @@ -216,6 +232,10 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage
if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range textResponse.Choices {
if choice.Text != "" {
completionTokens += CountTokenText(choice.Text, meta.ActualModel)
continue
}
completionTokens += CountTokenText(choice.Message.StringContent(), meta.ActualModel)
}
textResponse.Usage = model.Usage{
Expand Down
1 change: 1 addition & 0 deletions service/aiproxy/relay/adaptor/openai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ type TextResponseChoice struct {
FinishReason string `json:"finish_reason"`
Message model.Message `json:"message"`
Index int `json:"index"`
Text string `json:"text"`
}

type TextResponse struct {
Expand Down

0 comments on commit 646f592

Please sign in to comment.