From 015469581da7828a2fe93388487d0039454202b0 Mon Sep 17 00:00:00 2001
From: SLKun
Date: Tue, 6 Aug 2024 23:43:20 +0800
Subject: [PATCH] feat: update Ollama embedding API to latest version with
 multi-text embedding support (#1715)

---
 relay/adaptor/ollama/adaptor.go |  2 +-
 relay/adaptor/ollama/main.go    | 25 +++++++++++++++++--------
 relay/adaptor/ollama/model.go   | 12 ++++++++----
 3 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/relay/adaptor/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go
index 66702c5dde..ad1f898350 100644
--- a/relay/adaptor/ollama/adaptor.go
+++ b/relay/adaptor/ollama/adaptor.go
@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	// https://github.com/ollama/ollama/blob/main/docs/api.md
 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
 	if meta.Mode == relaymode.Embeddings {
-		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
+		fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL)
 	}
 	return fullRequestURL, nil
 }
diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go
index 936a7e144e..6a1d334d1a 100644
--- a/relay/adaptor/ollama/main.go
+++ b/relay/adaptor/ollama/main.go
@@ -157,8 +157,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 
 func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 	return &EmbeddingRequest{
-		Model:  request.Model,
-		Prompt: strings.Join(request.ParseInput(), " "),
+		Model: request.Model,
+		Input: request.ParseInput(),
+		Options: &Options{
+			Seed:             int(request.Seed),
+			Temperature:      request.Temperature,
+			TopP:             request.TopP,
+			FrequencyPenalty: request.FrequencyPenalty,
+			PresencePenalty:  request.PresencePenalty,
+		},
 	}
 }
 
@@ -201,15 +208,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi
 	openAIEmbeddingResponse := openai.EmbeddingResponse{
 		Object: "list",
 		Data:   make([]openai.EmbeddingResponseItem, 0, 1),
-		Model:  "text-embedding-v1",
+		Model:  response.Model,
 		Usage:  model.Usage{TotalTokens: 0},
 	}
 
-	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
-		Object:    `embedding`,
-		Index:     0,
-		Embedding: response.Embedding,
-	})
+	for i, embedding := range response.Embeddings {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+			Object:    `embedding`,
+			Index:     i,
+			Embedding: embedding,
+		})
+	}
 	return &openAIEmbeddingResponse
 }
 
diff --git a/relay/adaptor/ollama/model.go b/relay/adaptor/ollama/model.go
index 8baf56a040..29430e1c7c 100644
--- a/relay/adaptor/ollama/model.go
+++ b/relay/adaptor/ollama/model.go
@@ -37,11 +37,15 @@ type ChatResponse struct {
 }
 
 type EmbeddingRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
+	Model string   `json:"model"`
+	Input []string `json:"input"`
+	// Truncate bool `json:"truncate,omitempty"`
+	Options *Options `json:"options,omitempty"`
+	// KeepAlive string `json:"keep_alive,omitempty"`
 }
 
 type EmbeddingResponse struct {
-	Error     string    `json:"error,omitempty"`
-	Embedding []float64 `json:"embedding,omitempty"`
+	Error      string      `json:"error,omitempty"`
+	Model      string      `json:"model"`
+	Embeddings [][]float64 `json:"embeddings"`
 }