Skip to content

Commit

Permalink
feat: support cohere first response time
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Jun 28, 2024
1 parent d767ae0 commit a7e3168
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion relay/channel/cohere/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
err, usage = cohereStreamHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
Expand Down
13 changes: 10 additions & 3 deletions relay/channel/cohere/relay-cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)

func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
Expand Down Expand Up @@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
}

func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
Expand Down Expand Up @@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
stopChan <- true
}()
service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
Expand All @@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
openaiResp.Id = responseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = modelName
openaiResp.Model = info.UpstreamModelName
if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
Expand Down Expand Up @@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
}
Expand Down

0 comments on commit a7e3168

Please sign in to comment.