Skip to content

Commit

Permalink
change: switch to builder pattern for options
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 committed Jul 24, 2024
1 parent 97f3e59 commit a146ce6
Showing 1 changed file with 25 additions and 30 deletions.
55 changes: 25 additions & 30 deletions embed_vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,34 @@ const (

const baseURLVertex = "https://us-central1-aiplatform.googleapis.com/v1"

type VertexOptions struct {
APIEndpoint string
AutoTruncate bool
type vertexConfig struct {
apiKey string
project string
model EmbeddingModelVertex

// Optional
apiEndpoint string
autoTruncate bool
}

func DefaultVertexOptions() *VertexOptions {
return &VertexOptions{
APIEndpoint: baseURLVertex,
AutoTruncate: false,
func NewVertexConfig(apiKey, project string, model EmbeddingModelVertex) *vertexConfig {
return &vertexConfig{
apiKey: apiKey,
project: project,
model: model,
apiEndpoint: baseURLVertex,
autoTruncate: false,
}
}

type VertexOption func(*VertexOptions)

func WithVertexAPIEndpoint(apiEndpoint string) VertexOption {
return func(o *VertexOptions) {
o.APIEndpoint = apiEndpoint
}
func (c *vertexConfig) WithAPIEndpoint(apiEndpoint string) *vertexConfig {
c.apiEndpoint = apiEndpoint
return c
}

func WithVertexAutoTruncate(autoTruncate bool) VertexOption {
return func(o *VertexOptions) {
o.AutoTruncate = autoTruncate
}
func (c *vertexConfig) WithAutoTruncate(autoTruncate bool) *vertexConfig {
c.autoTruncate = autoTruncate
return c
}

type vertexResponse struct {
Expand All @@ -64,16 +68,7 @@ type vertexEmbeddings struct {
// there's more here, but we only care about the embeddings
}

func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts ...VertexOption) EmbeddingFunc {

cfg := DefaultVertexOptions()
for _, opt := range opts {
opt(cfg)
}

if cfg.APIEndpoint == "" {
cfg.APIEndpoint = baseURLVertex
}
func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {

// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
Expand All @@ -92,7 +87,7 @@ func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex,
},
},
"parameters": map[string]any{
"autoTruncate": cfg.AutoTruncate,
"autoTruncate": config.autoTruncate,
},
}

Expand All @@ -102,7 +97,7 @@ func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex,
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
}

fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", cfg.APIEndpoint, project, model)
fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", config.apiEndpoint, config.project, config.model)

// Create the request. Creating it with context is important for a timeout
// to be possible, because the client is configured without a timeout.
Expand All @@ -112,7 +107,7 @@ func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex,
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Authorization", "Bearer "+config.apiKey)

// Send the request.
resp, err := client.Do(req)
Expand Down

0 comments on commit a146ce6

Please sign in to comment.