Skip to content

Commit

Permalink
change: switch to builder pattern for options
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 committed Aug 13, 2024
1 parent 3f42010 commit d838df0
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 43 deletions.
12 changes: 5 additions & 7 deletions embed_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ const (
func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
// Mistral embeddings are normalized, see section "Distance Measures" on
// https://docs.mistral.ai/guides/embeddings/.
normalized := true

// The Mistral API docs don't mention the `encoding_format` as optional,
// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(baseURLMistral, apiKey, embeddingModelMistral).WithNormalized(true))
}

const baseURLJina = "https://api.jina.ai/v1"
Expand All @@ -36,7 +34,7 @@ const (
// NewEmbeddingFuncJina returns a function that creates embeddings for a text
// using the Jina API.
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(baseURLJina, apiKey, string(model)))
}

const baseURLMixedbread = "https://api.mixedbread.ai"
Expand Down Expand Up @@ -69,7 +67,7 @@ const (
// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
// using the mixedbread.ai API.
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(baseURLMixedbread, apiKey, string(model)))
}

const baseURLLocalAI = "http://localhost:8080/v1"
Expand All @@ -84,7 +82,7 @@ const baseURLLocalAI = "http://localhost:8080/v1"
// But other embedding models are supported as well. See the LocalAI documentation
// for details.
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(baseURLLocalAI, "", model))
}

const (
Expand All @@ -99,5 +97,5 @@ func NewEmbeddingFuncAzureOpenAI(apiKey string, deploymentURL string, apiVersion
if apiVersion == "" {
apiVersion = azureDefaultAPIVersion
}
return NewEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, WithOpenAICompatHeaders(map[string]string{"api-key": apiKey}), WithOpenAICompatQueryParams(map[string]string{"api-version": apiVersion}))
return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(deploymentURL, apiKey, model).WithHeaders(map[string]string{"api-key": apiKey}).WithQueryParams(map[string]string{"api-version": apiVersion}))
}
78 changes: 43 additions & 35 deletions embed_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ func NewEmbeddingFuncDefault() EmbeddingFunc {
// using the OpenAI API.
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
// OpenAI embeddings are normalized
normalized := true
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
return NewEmbeddingFuncOpenAICompat(NewOpenAICompatConfig(BaseURLOpenAI, apiKey, string(model)).WithNormalized(true))
}

// NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
Expand All @@ -61,31 +60,30 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding
// model are already normalized, as is the case for OpenAI's and Mistral's models.
// The flag is optional and is set via WithNormalized on the config. If it's
// left unset (nil), it will be autodetected on the first request (which bears
// a small risk that the vector just happens to have a length of 1).
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, opts ...OpenAICompatOption) EmbeddingFunc {
func NewEmbeddingFuncOpenAICompat(config *openAICompatConfig) EmbeddingFunc {
if config == nil {
panic("config must not be nil")
}

// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
// and it might have to be a long timeout, depending on the text length.
client := &http.Client{}

cfg := DefaultOpenAICompatOptions()
for _, opt := range opts {
opt(cfg)
}

var checkedNormalized bool
checkNormalized := sync.Once{}

return func(ctx context.Context, text string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
"input": text,
"model": model,
"model": config.model,
})
if err != nil {
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
}

fullURL, err := url.JoinPath(baseURL, cfg.EmbeddingsEndpoint)
fullURL, err := url.JoinPath(config.baseURL, config.embeddingsEndpoint)
if err != nil {
return nil, fmt.Errorf("couldn't join base URL and endpoint: %w", err)
}
Expand All @@ -97,16 +95,16 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
return nil, fmt.Errorf("couldn't create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Authorization", "Bearer "+config.apiKey)

// Add headers
for k, v := range cfg.Headers {
for k, v := range config.headers {
req.Header.Add(k, v)
}

// Add query parameters
q := req.URL.Query()
for k, v := range cfg.QueryParams {
for k, v := range config.queryParams {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
Expand Down Expand Up @@ -140,8 +138,8 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
}

v := embeddingResponse.Data[0].Embedding
if normalized != nil {
if *normalized {
if config.normalized != nil {
if *config.normalized {
return v, nil
}
return normalizeVector(v), nil
Expand All @@ -161,34 +159,44 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
}
}

type OpenAICompatOptions struct {
EmbeddingsEndpoint string
Headers map[string]string
QueryParams map[string]string
// openAICompatConfig holds the settings that NewEmbeddingFuncOpenAICompat reads
// when it builds each embeddings request. Create it with NewOpenAICompatConfig
// and adjust the optional fields through the chained With... methods.
type openAICompatConfig struct {
// baseURL is joined with embeddingsEndpoint (via url.JoinPath) to form the
// request URL.
baseURL string
// apiKey is sent as "Bearer <apiKey>" in the Authorization header.
apiKey string
// model is sent as the "model" field of the JSON request body.
model string

// Optional
// normalized, when non-nil, tells the embedding function whether the API
// already returns normalized vectors: true returns them unchanged, false
// passes them through normalizeVector. It stays nil unless WithNormalized
// is called.
normalized *bool
// embeddingsEndpoint is the path appended to baseURL; NewOpenAICompatConfig
// sets it to "/embeddings".
embeddingsEndpoint string
// headers are added to every request with Header.Add, after Content-Type and
// Authorization have been set.
headers map[string]string
// queryParams are added to the query string of every request URL.
queryParams map[string]string
}

type OpenAICompatOption func(*OpenAICompatOptions)
func NewOpenAICompatConfig(baseURL, apiKey, model string) *openAICompatConfig {
return &openAICompatConfig{
baseURL: baseURL,
apiKey: apiKey,
model: model,

func WithOpenAICompatEmbeddingsEndpointOverride(endpoint string) OpenAICompatOption {
return func(o *OpenAICompatOptions) {
o.EmbeddingsEndpoint = endpoint
embeddingsEndpoint: "/embeddings",
}
}

func WithOpenAICompatHeaders(headers map[string]string) OpenAICompatOption {
return func(o *OpenAICompatOptions) {
o.Headers = headers
}
// WithEmbeddingsEndpoint replaces the endpoint path that is joined onto the
// base URL when a request is built (NewOpenAICompatConfig sets "/embeddings").
// It modifies the receiver in place and returns it so that calls can be chained.
func (c *openAICompatConfig) WithEmbeddingsEndpoint(endpoint string) *openAICompatConfig {
c.embeddingsEndpoint = endpoint
return c
}

func WithOpenAICompatQueryParams(queryParams map[string]string) OpenAICompatOption {
return func(o *OpenAICompatOptions) {
o.QueryParams = queryParams
}
// WithHeaders sets extra HTTP headers for the embeddings requests. They are
// added with Header.Add after the Content-Type and Authorization headers have
// been set. Any previously configured headers map is replaced.
//
// The map is stored by reference, not copied, so later changes the caller makes
// to it are visible to the config. It modifies the receiver in place and
// returns it so that calls can be chained.
func (c *openAICompatConfig) WithHeaders(headers map[string]string) *openAICompatConfig {
c.headers = headers
return c
}

func DefaultOpenAICompatOptions() *OpenAICompatOptions {
return &OpenAICompatOptions{
EmbeddingsEndpoint: "/embeddings",
}
// WithQueryParams sets extra query parameters that are added to the URL of the
// embeddings requests. Any previously configured query parameter map is
// replaced.
//
// The map is stored by reference, not copied, so later changes the caller makes
// to it are visible to the config. It modifies the receiver in place and
// returns it so that calls can be chained.
func (c *openAICompatConfig) WithQueryParams(queryParams map[string]string) *openAICompatConfig {
c.queryParams = queryParams
return c
}

// WithNormalized records whether the API already returns normalized embeddings.
// With true, the embedding function returns the vectors unchanged; with false,
// it passes them through normalizeVector. If this method is never called, the
// setting stays nil (unset).
//
// It modifies the receiver in place and returns it so that calls can be chained.
func (c *openAICompatConfig) WithNormalized(normalized bool) *openAICompatConfig {
// Store a pointer to the parameter copy so that nil keeps meaning "not set".
c.normalized = &normalized
return c
}
2 changes: 1 addition & 1 deletion embed_openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
defer ts.Close()
baseURL := ts.URL + baseURLSuffix

f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil)
f := chromem.NewEmbeddingFuncOpenAICompat(chromem.NewOpenAICompatConfig(baseURL, apiKey, model))
res, err := f(context.Background(), input)
if err != nil {
t.Fatal("expected nil, got", err)
Expand Down

0 comments on commit d838df0

Please sign in to comment.