Skip to content

Commit

Permalink
feat: support ollama multi-text embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Sep 12, 2024
1 parent 1675679 commit 0cbf8e0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
18 changes: 16 additions & 2 deletions relay/channel/ollama/dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,25 @@ type OllamaRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}

type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}

type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"`
Prompt any `json:"prompt,omitempty"`
Model string `json:"model,omitempty"`
Input []string `json:"input"`
Options *Options `json:"options,omitempty"`
}

type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding []float64 `json:"embedding,omitempty"`
}
15 changes: 12 additions & 3 deletions relay/channel/ollama/relay-ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"one-api/dto"
"one-api/service"
"strings"
)

func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
Expand Down Expand Up @@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {

func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "),
Model: request.Model,
Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
}
}

Expand All @@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if ollamaEmbeddingResponse.Error != "" {
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding,
Expand Down

0 comments on commit 0cbf8e0

Please sign in to comment.