From a146ce68fefdbabba983c2bb4acf97c435956de4 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Wed, 24 Jul 2024 11:33:39 +0200 Subject: [PATCH] change: switch to builder pattern for options --- embed_vertex.go | 55 ++++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/embed_vertex.go b/embed_vertex.go index d800894..6efc667 100644 --- a/embed_vertex.go +++ b/embed_vertex.go @@ -25,30 +25,34 @@ const ( const baseURLVertex = "https://us-central1-aiplatform.googleapis.com/v1" -type VertexOptions struct { - APIEndpoint string - AutoTruncate bool +type vertexConfig struct { + apiKey string + project string + model EmbeddingModelVertex + + // Optional + apiEndpoint string + autoTruncate bool } -func DefaultVertexOptions() *VertexOptions { - return &VertexOptions{ - APIEndpoint: baseURLVertex, - AutoTruncate: false, +func NewVertexConfig(apiKey, project string, model EmbeddingModelVertex) *vertexConfig { + return &vertexConfig{ + apiKey: apiKey, + project: project, + model: model, + apiEndpoint: baseURLVertex, + autoTruncate: false, } } -type VertexOption func(*VertexOptions) - -func WithVertexAPIEndpoint(apiEndpoint string) VertexOption { - return func(o *VertexOptions) { - o.APIEndpoint = apiEndpoint - } +func (c *vertexConfig) WithAPIEndpoint(apiEndpoint string) *vertexConfig { + c.apiEndpoint = apiEndpoint + return c } -func WithVertexAutoTruncate(autoTruncate bool) VertexOption { - return func(o *VertexOptions) { - o.AutoTruncate = autoTruncate - } +func (c *vertexConfig) WithAutoTruncate(autoTruncate bool) *vertexConfig { + c.autoTruncate = autoTruncate + return c } type vertexResponse struct { @@ -64,16 +68,7 @@ type vertexEmbeddings struct { // there's more here, but we only care about the embeddings } -func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts ...VertexOption) EmbeddingFunc { - - cfg := DefaultVertexOptions() - for _, opt := range opts { - opt(cfg) - } - - if cfg.APIEndpoint == "" { - cfg.APIEndpoint = baseURLVertex - } +func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc { // We don't set a default timeout here, although it's usually a good idea. // In our case though, the library user can set the timeout on the context, @@ -92,7 +87,7 @@ func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, }, }, "parameters": map[string]any{ - "autoTruncate": cfg.AutoTruncate, + "autoTruncate": config.autoTruncate, }, } @@ -102,7 +97,7 @@ func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, return nil, fmt.Errorf("couldn't marshal request body: %w", err) } - fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", cfg.APIEndpoint, project, model) + fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", config.apiEndpoint, config.project, config.model) // Create the request. Creating it with context is important for a timeout // to be possible, because the client is configured without a timeout. @@ -112,7 +107,7 @@ func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, } req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Authorization", "Bearer "+config.apiKey) // Send the request. resp, err := client.Do(req)