Skip to content

Commit

Permalink
Merge pull request #93 from philippgille/revert-to-vertex-variadic-op…
Browse files Browse the repository at this point in the history
…tions

Use variadic functions for Vertex options
  • Loading branch information
philippgille authored Sep 1, 2024
2 parents b5632c0 + 725930f commit fce1e85
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions embed_vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,30 @@ const (

const baseURLVertex = "https://us-central1-aiplatform.googleapis.com/v1"

type vertexConfig struct {
apiKey string
project string
model EmbeddingModelVertex

// Optional
type vertexOptions struct {
apiEndpoint string
autoTruncate bool
}

func NewVertexConfig(apiKey, project string, model EmbeddingModelVertex) *vertexConfig {
return &vertexConfig{
apiKey: apiKey,
project: project,
model: model,
func defaultVertexOptions() *vertexOptions {
return &vertexOptions{
apiEndpoint: baseURLVertex,
autoTruncate: false,
}
}

func (c *vertexConfig) WithAPIEndpoint(apiEndpoint string) *vertexConfig {
c.apiEndpoint = apiEndpoint
return c
type VertexOption func(*vertexOptions)

func WithVertexAPIEndpoint(apiEndpoint string) VertexOption {
return func(o *vertexOptions) {
o.apiEndpoint = apiEndpoint
}
}

func (c *vertexConfig) WithAutoTruncate(autoTruncate bool) *vertexConfig {
c.autoTruncate = autoTruncate
return c
func WithVertexAutoTruncate(autoTruncate bool) VertexOption {
return func(o *vertexOptions) {
o.autoTruncate = autoTruncate
}
}

type vertexResponse struct {
Expand All @@ -68,7 +64,15 @@ type vertexEmbeddings struct {
// there's more here, but we only care about the embeddings
}

func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {
func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts ...VertexOption) EmbeddingFunc {
cfg := defaultVertexOptions()
for _, opt := range opts {
opt(cfg)
}

if cfg.apiEndpoint == "" {
cfg.apiEndpoint = baseURLVertex
}

// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
Expand All @@ -79,15 +83,14 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {
checkNormalized := sync.Once{}

return func(ctx context.Context, text string) ([]float32, error) {

b := map[string]any{
"instances": []map[string]any{
{
"content": text,
},
},
"parameters": map[string]any{
"autoTruncate": config.autoTruncate,
"autoTruncate": cfg.autoTruncate,
},
}

Expand All @@ -97,7 +100,7 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
}

fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", config.apiEndpoint, config.project, config.model)
fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", cfg.apiEndpoint, project, model)

// Create the request. Creating it with context is important for a timeout
// to be possible, because the client is configured without a timeout.
Expand All @@ -107,7 +110,7 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+config.apiKey)
req.Header.Set("Authorization", "Bearer "+apiKey)

// Send the request.
resp, err := client.Do(req)
Expand Down

0 comments on commit fce1e85

Please sign in to comment.