From cb7388935377848bbd43c45c65ed48e0e2b29023 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 13 Sep 2024 03:17:04 +0800 Subject: [PATCH] feat: support o1 channel test --- controller/channel-test.go | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 95c4a60aa..ff663864c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -20,6 +20,7 @@ import ( "one-api/relay/constant" "one-api/service" "strconv" + "strings" "sync" "time" @@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } - request := buildTestRequest() - request.Model = testModel + request := buildTestRequest(testModel) meta.UpstreamModelName = testModel common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) @@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return nil, nil } -func buildTestRequest() *dto.GeneralOpenAIRequest { +func buildTestRequest(model string) *dto.GeneralOpenAIRequest { testRequest := &dto.GeneralOpenAIRequest{ - Model: "", // this will be set later - MaxTokens: 1, - Stream: false, + Model: "", // this will be set later + Stream: false, + } + if strings.HasPrefix(model, "o1-") { + testRequest.MaxCompletionTokens = 1 + } else { + testRequest.MaxTokens = 1 } content, _ := json.Marshal("hi") testMessage := dto.Message{ Role: "user", Content: content, } + testRequest.Model = model testRequest.Messages = append(testRequest.Messages, testMessage) return testRequest } @@ -226,26 +231,22 @@ func testAllChannels(notify bool) error { tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() - ban := false - if milliseconds > disableThreshold { - err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - ban = true - } + shouldBanChannel := false // request error disables the channel if openaiWithStatusErr != nil { oaiErr := openaiWithStatusErr.Error err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) - ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) + shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) } - // parse *int to bool - if !channel.GetAutoBan() { - ban = false + if milliseconds > disableThreshold { + err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + shouldBanChannel = true } // disable channel - if ban && isChannelEnabled { + if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { service.DisableChannel(channel.Id, channel.Name, err.Error()) }