Skip to content

Commit

Permalink
feat: reduce the number of stream memory copies
Browse files Browse the repository at this point in the history
  • Loading branch information
zijiren233 committed Feb 24, 2025
1 parent c89e0f5 commit 60b7b65
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 16 deletions.
47 changes: 39 additions & 8 deletions service/aiproxy/relay/adaptor/openai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"io"
"net/http"
"slices"
"strings"
"sync"

"github.com/gin-gonic/gin"
json "github.com/json-iterator/go"
Expand Down Expand Up @@ -37,14 +39,38 @@ type UsageAndChoicesResponse struct {
Choices []*ChatCompletionsStreamResponseChoice
}

// scannerBufferSize is the fixed size of every pooled scanner buffer:
// twice bufio's default maximum token size, so unusually long SSE lines
// still fit in a single token without reallocating.
const scannerBufferSize = 2 * bufio.MaxScanTokenSize

// scannerBufferPool recycles large scanner buffers across stream
// requests, avoiding a fresh multi-megabyte allocation per call.
var scannerBufferPool = sync.Pool{
	New: func() any {
		// S1019 (gosimple): when len == cap, the explicit capacity
		// argument to make is redundant.
		buf := make([]byte, scannerBufferSize)
		// Store a pointer so Get/Put move a single word instead of
		// boxing the three-word slice header on every Put.
		return &buf
	},
}

// getScannerBuffer fetches a reusable scanner buffer from the pool.
// Callers must hand it back via putScannerBuffer when done.
//
//nolint:forcetypeassert // the pool's New only ever produces *[]byte.
func getScannerBuffer() *[]byte {
	buf := scannerBufferPool.Get().(*[]byte)
	return buf
}

// putScannerBuffer returns buf to the pool for reuse. Buffers whose
// capacity no longer matches the expected pool size are silently
// discarded so the pool stays homogeneous.
func putScannerBuffer(buf *[]byte) {
	if cap(*buf) == scannerBufferSize {
		scannerBufferPool.Put(buf)
	}
}

func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *model.ErrorWithStatusCode) {
defer resp.Body.Close()

log := middleware.GetLogger(c)

responseText := ""
responseText := strings.Builder{}

scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
buf := getScannerBuffer()
defer putScannerBuffer(buf)
scanner.Buffer(*buf, cap(*buf))

var usage *model.Usage

Expand Down Expand Up @@ -81,7 +107,9 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
usage = streamResponse.Usage
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.StringContent()
if usage == nil {
responseText.WriteString(choice.Delta.StringContent())
}
if choice.Delta.ReasoningContent != "" {
hasReasoningContent = true
}
Expand Down Expand Up @@ -109,11 +137,12 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
log.Error("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage
} else {
for _, choice := range streamResponse.Choices {
responseText.WriteString(choice.Text)
}
}
respMap := make(map[string]any)
err = json.Unmarshal(data, &respMap)
Expand All @@ -134,8 +163,8 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model

render.Done(c)

if usage == nil || (usage.TotalTokens == 0 && responseText != "") {
usage = ResponseText2Usage(responseText, meta.ActualModel, meta.InputTokens)
if usage == nil || (usage.TotalTokens == 0 && responseText.Len() > 0) {
usage = ResponseText2Usage(responseText.String(), meta.ActualModel, meta.InputTokens)
}

if usage.TotalTokens != 0 && usage.PromptTokens == 0 { // some channels don't return prompt tokens & completion tokens
Expand Down Expand Up @@ -220,11 +249,13 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage
if err != nil {
return nil, ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}

var textResponse SlimTextResponse
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return nil, ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}

if textResponse.Error.Type != "" {
return nil, ErrorWrapperWithMessage(textResponse.Error.Message, textResponse.Error.Code, http.StatusBadRequest)
}
Expand Down
21 changes: 13 additions & 8 deletions service/aiproxy/relay/adaptor/openai/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ func CountTokenMessages(messages []*model.Message, model string) int {
tokenNum += getTokenNum(tokenEncoder, v)
case []any:
for _, it := range v {
m := it.(map[string]any)
m, ok := it.(map[string]any)
if !ok {
continue
}
switch m["type"] {
case "text":
if textValue, ok := m["text"]; ok {
Expand All @@ -90,10 +93,16 @@ func CountTokenMessages(messages []*model.Message, model string) int {
case "image_url":
imageURL, ok := m["image_url"].(map[string]any)
if ok {
url := imageURL["url"].(string)
url, ok := imageURL["url"].(string)
if !ok {
continue
}
detail := ""
if imageURL["detail"] != nil {
detail = imageURL["detail"].(string)
detail, ok = imageURL["detail"].(string)
if !ok {
continue
}
}
imageTokens, err := countImageTokens(url, detail, model)
if err != nil {
Expand Down Expand Up @@ -217,11 +226,7 @@ func CountTokenText(text string, model string) int {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text)
}
if strings.HasPrefix(model, "sambert-") {
return len(text)
}
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
return getTokenNum(getTokenEncoder(model), text)
}

func CountToken(text string) int {
Expand Down

0 comments on commit 60b7b65

Please sign in to comment.