diff --git a/embed_compat.go b/embed_compat.go
index 2339730..75ee528 100644
--- a/embed_compat.go
+++ b/embed_compat.go
@@ -11,11 +11,9 @@ const (
 func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
 	// Mistral embeddings are normalized, see section "Distance Measures" on
 	// https://docs.mistral.ai/guides/embeddings/.
-	normalized := true
-
 	// The Mistral API docs don't mention the `encoding_format` as optional,
 	// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
-	return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
+	return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(baseURLMistral, apiKey, embeddingModelMistral).WithNormalized(true))
 }
 
 const baseURLJina = "https://api.jina.ai/v1"
@@ -36,7 +34,7 @@ const (
 // NewEmbeddingFuncJina returns a function that creates embeddings for a text
 // using the Jina API.
 func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
+	return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(baseURLJina, apiKey, string(model)))
 }
 
 const baseURLMixedbread = "https://api.mixedbread.ai"
@@ -69,7 +67,7 @@ const (
 // NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
 // using the mixedbread.ai API.
 func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
+	return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(baseURLMixedbread, apiKey, string(model)))
 }
 
 const baseURLLocalAI = "http://localhost:8080/v1"
@@ -84,7 +82,7 @@ const baseURLLocalAI = "http://localhost:8080/v1"
 // But other embedding models are supported as well. See the LocalAI documentation
 // for details.
 func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
+	return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(baseURLLocalAI, "", model))
 }
 
 const (
@@ -99,5 +97,5 @@ func NewEmbeddingFuncAzureOpenAI(apiKey string, deploymentURL string, apiVersion
 	if apiVersion == "" {
 		apiVersion = azureDefaultAPIVersion
 	}
-	return NewEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, WithOpenAICompatHeaders(map[string]string{"api-key": apiKey}), WithOpenAICompatQueryParams(map[string]string{"api-version": apiVersion}))
+	return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(deploymentURL, apiKey, model).WithHeaders(map[string]string{"api-key": apiKey}).WithQueryParams(map[string]string{"api-version": apiVersion}))
 }
diff --git a/embed_openai.go b/embed_openai.go
index 75456ca..df4f2ea 100644
--- a/embed_openai.go
+++ b/embed_openai.go
@@ -43,8 +43,7 @@ func NewEmbeddingFuncDefault() EmbeddingFunc {
 // using the OpenAI API.
 func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
 	// OpenAI embeddings are normalized
-	normalized := true
-	return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
+	return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(BaseURLOpenAI, apiKey, string(model)).WithNormalized(true))
 }
 
 // NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
@@ -61,17 +60,16 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding
 // model are already normalized, as is the case for OpenAI's and Mistral's models.
 // The flag is optional. If it's nil, it will be autodetected on the first request
 // (which bears a small risk that the vector just happens to have a length of 1).
-func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, opts ...OpenAICompatOption) EmbeddingFunc {
+func NewEmbeddingFuncOpenAICompat(config *openAICompatConfig) EmbeddingFunc {
+	if config == nil {
+		panic("config must not be nil")
+	}
+
 	// We don't set a default timeout here, although it's usually a good idea.
 	// In our case though, the library user can set the timeout on the context,
 	// and it might have to be a long timeout, depending on the text length.
 	client := &http.Client{}
 
-	cfg := DefaultOpenAICompatOptions()
-	for _, opt := range opts {
-		opt(cfg)
-	}
-
 	var checkedNormalized bool
 	checkNormalized := sync.Once{}
 
@@ -79,13 +77,13 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
 		// Prepare the request body.
 		reqBody, err := json.Marshal(map[string]string{
 			"input": text,
-			"model": model,
+			"model": config.model,
 		})
 		if err != nil {
 			return nil, fmt.Errorf("couldn't marshal request body: %w", err)
 		}
 
-		fullURL, err := url.JoinPath(baseURL, cfg.EmbeddingsEndpoint)
+		fullURL, err := url.JoinPath(config.baseURL, config.embeddingsEndpoint)
 		if err != nil {
 			return nil, fmt.Errorf("couldn't join base URL and endpoint: %w", err)
 		}
@@ -97,16 +95,16 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
 			return nil, fmt.Errorf("couldn't create request: %w", err)
 		}
 		req.Header.Set("Content-Type", "application/json")
-		req.Header.Set("Authorization", "Bearer "+apiKey)
+		req.Header.Set("Authorization", "Bearer "+config.apiKey)
 
 		// Add headers
-		for k, v := range cfg.Headers {
+		for k, v := range config.headers {
 			req.Header.Add(k, v)
 		}
 
 		// Add query parameters
 		q := req.URL.Query()
-		for k, v := range cfg.QueryParams {
+		for k, v := range config.queryParams {
 			q.Add(k, v)
 		}
 		req.URL.RawQuery = q.Encode()
@@ -140,8 +138,8 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
 		}
 
 		v := embeddingResponse.Data[0].Embedding
-		if normalized != nil {
-			if *normalized {
+		if config.normalized != nil {
+			if *config.normalized {
 				return v, nil
 			}
 			return normalizeVector(v), nil
@@ -161,34 +159,44 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
 	}
 }
 
-type OpenAICompatOptions struct {
-	EmbeddingsEndpoint string
-	Headers            map[string]string
-	QueryParams        map[string]string
+type openAICompatConfig struct {
+	baseURL string
+	apiKey  string
+	model   string
+
+	// Optional
+	normalized         *bool
+	embeddingsEndpoint string
+	headers            map[string]string
+	queryParams        map[string]string
 }
 
-type OpenAICompatOption func(*OpenAICompatOptions)
+func NewOpenAICompatConfig(baseURL, apiKey, model string) *openAICompatConfig {
+	return &openAICompatConfig{
+		baseURL: baseURL,
+		apiKey:  apiKey,
+		model:   model,
 
-func WithOpenAICompatEmbeddingsEndpointOverride(endpoint string) OpenAICompatOption {
-	return func(o *OpenAICompatOptions) {
-		o.EmbeddingsEndpoint = endpoint
+		embeddingsEndpoint: "/embeddings",
 	}
 }
 
-func WithOpenAICompatHeaders(headers map[string]string) OpenAICompatOption {
-	return func(o *OpenAICompatOptions) {
-		o.Headers = headers
-	}
+func (c *openAICompatConfig) WithEmbeddingsEndpoint(endpoint string) *openAICompatConfig {
+	c.embeddingsEndpoint = endpoint
+	return c
 }
 
-func WithOpenAICompatQueryParams(queryParams map[string]string) OpenAICompatOption {
-	return func(o *OpenAICompatOptions) {
-		o.QueryParams = queryParams
-	}
+func (c *openAICompatConfig) WithHeaders(headers map[string]string) *openAICompatConfig {
+	c.headers = headers
+	return c
 }
 
-func DefaultOpenAICompatOptions() *OpenAICompatOptions {
-	return &OpenAICompatOptions{
-		EmbeddingsEndpoint: "/embeddings",
-	}
+func (c *openAICompatConfig) WithQueryParams(queryParams map[string]string) *openAICompatConfig {
+	c.queryParams = queryParams
+	return c
+}
+
+func (c *openAICompatConfig) WithNormalized(normalized bool) *openAICompatConfig {
+	c.normalized = &normalized
+	return c
 }
diff --git a/embed_openai_test.go b/embed_openai_test.go
index 5243b81..e5bcfa4 100644
--- a/embed_openai_test.go
+++ b/embed_openai_test.go
@@ -75,7 +75,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
 	defer ts.Close()
 
 	baseURL := ts.URL + baseURLSuffix
-	f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil)
+	f := chromem.NewEmbeddingFuncOpenAICompat(chromem.NewOpenAICompatConfig(baseURL, apiKey, model))
 	res, err := f(context.Background(), input)
 	if err != nil {
 		t.Fatal("expected nil, got", err)
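
Usage note (not part of the diff): a minimal sketch of how a library user might call the new builder-style API after this change. The base URL and model name are placeholder assumptions for illustration; the import path is the chromem-go module.

package main

import (
	"context"
	"fmt"

	chromem "github.com/philippgille/chromem-go"
)

func main() {
	// Required parameters go in the constructor; optional ones are chained.
	// WithNormalized(true) tells the embedding func that the model already
	// returns unit-length vectors, so it skips normalization and autodetection.
	cfg := chromem.NewOpenAICompatConfig(
		"http://localhost:8080/v1", // assumed OpenAI-compatible endpoint
		"",                         // API key; empty for an unauthenticated local server
		"text-embedding-3-small",   // assumed model name
	).WithNormalized(true)

	embed := chromem.NewEmbeddingFuncOpenAICompat(cfg)
	vec, err := embed(context.Background(), "The sky is blue because of Rayleigh scattering.")
	if err != nil {
		panic(err)
	}
	fmt.Println("embedding dimensions:", len(vec))
}

Compared to the previous functional options (WithOpenAICompatHeaders etc.), the builder keeps required and optional parameters in one typed value and shortens call sites, at the cost of an exported constructor returning an unexported type.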