Skip to content

Commit

Permalink
feat: reduce the number of stream memory copies (#5402)
Browse files Browse the repository at this point in the history
* fix: generalize the openai data prefix

* feat: reduce the number of stream memory copies

* fix: lint

* feat: reduce the number of stream memory copies

* fix: lint

* fix: reset buf when usage

* chore: mv event to render

* fix: ignore semgrep error
  • Loading branch information
zijiren233 authored Feb 25, 2025
1 parent a855c20 commit 40ec466
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 139 deletions.
67 changes: 0 additions & 67 deletions service/aiproxy/common/custom-event.go

This file was deleted.

51 changes: 51 additions & 0 deletions service/aiproxy/common/render/event.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package render

import (
"net/http"

"github.com/labring/sealos/service/aiproxy/common/conv"
)

// Header values used by OpenAISSE.WriteContentType. They are declared once at
// package level and assigned directly into the response's http.Header map, so
// the same slices are shared by every response that uses them.
var (
// contentType is the Content-Type value for a server-sent event stream.
contentType = []string{"text/event-stream"}
// noCache is the Cache-Control value applied when the handler has not set one.
noCache = []string{"no-cache"}
)

// OpenAISSE renders a single server-sent event in the form
// "data: <Data>\n\n". See Render and WriteContentType.
type OpenAISSE struct {
// Data is the event payload; Render writes it as-is between the "data: "
// prefix and the terminating blank line.
Data string
}

// Fixed framing fragments written around every event payload.
const (
// nn is the blank-line terminator written after the payload.
nn = "\n\n"
// data is the field prefix written before the payload.
data = "data: "
)

// Byte-slice forms of the framing constants, converted once at package
// initialization so Render can pass them straight to Write.
// NOTE(review): conv.StringToBytes is presumably a zero-copy conversion; if so,
// these slices must never be modified — confirm against the conv package.
var (
nnBytes = conv.StringToBytes(nn)
dataBytes = conv.StringToBytes(data)
)

// Render sets the event-stream headers on w and then writes one event as
// three consecutive writes: the "data: " prefix, the payload held in r.Data,
// and the "\n\n" terminator. It stops at the first failed write and returns
// that error; it returns nil once all three writes succeed.
//
// Render does not flush w; callers that need the event delivered immediately
// must flush the writer themselves.
func (r OpenAISSE) Render(w http.ResponseWriter) error {
	r.WriteContentType(w)

	// Convert the payload up front; the three pieces are then written
	// separately rather than being joined into a new buffer first.
	payload := conv.StringToBytes(r.Data)

	// nosemgrep: go.lang.security.audit.xss.no-direct-write-to-responsewriter.no-direct-write-to-responsewriter
	if _, err := w.Write(dataBytes); err != nil {
		return err
	}
	// nosemgrep: go.lang.security.audit.xss.no-direct-write-to-responsewriter.no-direct-write-to-responsewriter
	if _, err := w.Write(payload); err != nil {
		return err
	}
	// nosemgrep: go.lang.security.audit.xss.no-direct-write-to-responsewriter.no-direct-write-to-responsewriter
	_, err := w.Write(nnBytes)
	return err
}

// WriteContentType prepares the response headers for an event stream.
// Content-Type is always overwritten with "text/event-stream". Cache-Control
// is set to "no-cache" only when the header map has no "Cache-Control" key,
// so a value set earlier by the handler is left untouched.
func (r OpenAISSE) WriteContentType(w http.ResponseWriter) {
	h := w.Header()
	h["Content-Type"] = contentType

	// Respect a Cache-Control value that is already present.
	if _, ok := h["Cache-Control"]; ok {
		return
	}
	h["Cache-Control"] = noCache
}
9 changes: 3 additions & 6 deletions service/aiproxy/common/render/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ package render
import (
"errors"
"fmt"
"strings"

"github.com/gin-gonic/gin"
json "github.com/json-iterator/go"
"github.com/labring/sealos/service/aiproxy/common"
"github.com/labring/sealos/service/aiproxy/common/conv"
)

Expand All @@ -20,9 +18,7 @@ func StringData(c *gin.Context, str string) {
if c.IsAborted() {
return
}
str = strings.TrimPrefix(str, "data:")
// str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + strings.TrimSpace(str)})
c.Render(-1, OpenAISSE{Data: str})
c.Writer.Flush()
}

Expand All @@ -37,7 +33,8 @@ func ObjectData(c *gin.Context, object any) error {
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
StringData(c, conv.BytesToString(jsonData))
c.Render(-1, OpenAISSE{Data: conv.BytesToString(jsonData)})
c.Writer.Flush()
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/aws/claude/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
c.Stream(func(_ io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
render.StringData(c, "[DONE]")
render.Done(c)
return false
}

Expand Down
3 changes: 1 addition & 2 deletions service/aiproxy/relay/adaptor/aws/llama3/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/labring/sealos/service/aiproxy/common"
"github.com/labring/sealos/service/aiproxy/common/random"
"github.com/labring/sealos/service/aiproxy/common/render"
"github.com/labring/sealos/service/aiproxy/middleware"
Expand Down Expand Up @@ -209,7 +208,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
c.Stream(func(_ io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
render.Done(c)
return false
}

Expand Down
4 changes: 2 additions & 2 deletions service/aiproxy/relay/adaptor/coze/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ func (a *Adaptor) DoRequest(_ *meta.Meta, _ *gin.Context, req *http.Request) (*h
func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (usage *relaymodel.Usage, err *relaymodel.ErrorWithStatusCode) {
var responseText *string
if utils.IsStreamResponse(resp) {
err, responseText = StreamHandler(c, resp)
err, responseText = StreamHandler(meta, c, resp)
} else {
err, responseText = Handler(c, resp, meta.InputTokens, meta.ActualModel)
err, responseText = Handler(meta, c, resp)
}
if responseText != nil {
usage = openai.ResponseText2Usage(*responseText, meta.ActualModel, meta.InputTokens)
Expand Down
10 changes: 5 additions & 5 deletions service/aiproxy/relay/adaptor/coze/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/labring/sealos/service/aiproxy/relay/adaptor/coze/constant/messagetype"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/constant"
"github.com/labring/sealos/service/aiproxy/relay/meta"
"github.com/labring/sealos/service/aiproxy/relay/model"
)

Expand Down Expand Up @@ -85,7 +86,7 @@ func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse {
return &fullTextResponse
}

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) {
func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) {
defer resp.Body.Close()

log := middleware.GetLogger(c)
Expand All @@ -96,7 +97,6 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
scanner.Split(bufio.ScanLines)

common.SetEventStreamHeaders(c)
var modelName string

for scanner.Scan() {
data := scanner.Bytes()
Expand Down Expand Up @@ -124,7 +124,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
for _, choice := range response.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
response.Model = modelName
response.Model = meta.OriginModel
response.Created = createdTime

_ = render.ObjectData(c, response)
Expand All @@ -139,7 +139,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return nil, &responseText
}

func Handler(c *gin.Context, resp *http.Response, _ int, modelName string) (*model.ErrorWithStatusCode, *string) {
func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) {
defer resp.Body.Close()

log := middleware.GetLogger(c)
Expand All @@ -159,7 +159,7 @@ func Handler(c *gin.Context, resp *http.Response, _ int, modelName string) (*mod
}, nil
}
fullTextResponse := ResponseCoze2OpenAI(&cozeResponse)
fullTextResponse.Model = modelName
fullTextResponse.Model = meta.OriginModel
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
Expand Down
Loading

0 comments on commit 40ec466

Please sign in to comment.