Skip to content

support embeddings #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type GenerativeModel struct {

const defaultMaxOutputTokens = 2048

// GenerativeModel creates a new instance of the named model.
// GenerativeModel creates a new instance of the named generative model.
func (c *Client) GenerativeModel(name string) *GenerativeModel {
return &GenerativeModel{
GenerationConfig: GenerationConfig{
Expand Down
18 changes: 18 additions & 0 deletions genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,24 @@ func TestLive(t *testing.T) {
}
checkMatch(t, responseString(res), "(it's}|weather) .*cold")
})
// The "embed" subtest exercises EmbedContent and EmbedContentWithTitle
// on the "embedding-001" model.
t.Run("embed", func(t *testing.T) {
em := client.EmbeddingModel("embedding-001")
// First, embed content without a title.
res, err := em.EmbedContent(ctx, Text("cheddar cheese"))
if err != nil {
t.Fatal(err)
}
// The embedding values themselves are not checked; the test only requires
// a non-nil result carrying at least 10 values.
if res == nil || res.Embedding == nil || len(res.Embedding.Values) < 10 {
t.Errorf("bad result: %v\n", res)
}

// Then embed content together with a title, reusing res and err.
res, err = em.EmbedContentWithTitle(ctx, "My Cheese Report", Text("I love cheddar cheese."))
if err != nil {
t.Fatal(err)
}
// Same sanity check as above.
if res == nil || res.Embedding == nil || len(res.Embedding.Values) < 10 {
t.Errorf("bad result: %v\n", res)
}
})
}

func TestJoinResponses(t *testing.T) {
Expand Down
14 changes: 11 additions & 3 deletions genai/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ types:
Candidate_FinishReason:
name: FinishReason
protoPrefix: Candidate_
fields:
GroundingAttributions:
omit: true

GenerateContentResponse_PromptFeedback_BlockReason:
name: BlockReason
Expand Down Expand Up @@ -114,6 +111,17 @@ types:

Schema:

TaskType:
protoPrefix: TaskType
valueNames:
TaskType_TASK_TYPE_UNSPECIFIED: TaskTypeUnspecified

EmbedContentResponse:

ContentEmbedding:




# Omit everything not explicitly configured.
omitTypes:
Expand Down
71 changes: 71 additions & 0 deletions genai/embed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package genai

import (
"context"
"fmt"

pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb"
)

// EmbeddingModel returns a new EmbeddingModel for the embedding model with
// the given name. The returned model keeps a reference to the client and
// records both the bare name and the name prefixed with "models/".
func (c *Client) EmbeddingModel(name string) *EmbeddingModel {
	m := new(EmbeddingModel)
	m.c = c
	m.name = name
	m.fullName = fmt.Sprintf("models/%s", name)
	return m
}

// EmbeddingModel is a model that computes embeddings.
// Create one with [Client.EmbeddingModel].
type EmbeddingModel struct {
// c is the client through which embedding requests are sent.
c *Client
// name is the model name as passed to [Client.EmbeddingModel].
name string
// fullName is name prefixed with "models/"; it is used as the Model
// field of embedding requests.
fullName string
// TaskType describes how the embedding will be used.
TaskType TaskType
}

// EmbedContent computes an embedding for the given parts.
// It is equivalent to calling EmbedContentWithTitle with an empty title.
func (m *EmbeddingModel) EmbedContent(ctx context.Context, parts ...Part) (*EmbedContentResponse, error) {
	const noTitle = ""
	return m.EmbedContentWithTitle(ctx, noTitle, parts...)
}

// EmbedContentWithTitle returns an embedding for the list of parts.
//
// The task type sent with the request is taken from m.TaskType. If the given
// title is non-empty, it is passed to the model and the task type is set to
// TaskTypeRetrievalDocument, overriding m.TaskType. If the resulting task
// type is TaskTypeUnspecified, the request's task type is left unset.
//
// It returns the error from the underlying client call unchanged.
func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string, parts ...Part) (*EmbedContentResponse, error) {
	req := &pb.EmbedContentRequest{
		Model:   m.fullName,
		Content: newUserContent(parts).toProto(),
	}
	// Start from the model's configured task type so that the exported
	// TaskType field takes effect; a non-empty title overrides it.
	tt := m.TaskType
	if title != "" {
		req.Title = &title
		tt = TaskTypeRetrievalDocument
	}
	// TaskType is an optional request field: only set it when there is a
	// specific value to send.
	if tt != TaskTypeUnspecified {
		taskType := pb.TaskType(tt)
		req.TaskType = &taskType
	}
	res, err := m.c.c.EmbedContent(ctx, req)
	if err != nil {
		return nil, err
	}
	return (EmbedContentResponse{}).fromProto(res), nil
}
82 changes: 82 additions & 0 deletions genai/generativelanguagepb_veneer.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,30 @@ func (Content) fromProto(p *pb.Content) *Content {
}
}

// ContentEmbedding holds the list of floats that represents an embedding.
type ContentEmbedding struct {
	// Values are the embedding values.
	Values []float32
}

// toProto converts v to its protobuf form. A nil receiver yields nil.
// The Values slice is shared with the result, not copied.
func (v *ContentEmbedding) toProto() *pb.ContentEmbedding {
	if v == nil {
		return nil
	}
	out := new(pb.ContentEmbedding)
	out.Values = v.Values
	return out
}

// fromProto builds a ContentEmbedding from its protobuf form. A nil p yields
// nil. The Values slice is shared with p, not copied.
func (ContentEmbedding) fromProto(p *pb.ContentEmbedding) *ContentEmbedding {
	if p == nil {
		return nil
	}
	out := new(ContentEmbedding)
	out.Values = p.Values
	return out
}

// CountTokensResponse is a response from `CountTokens`.
//
// It returns the model's `token_count` for the `prompt`.
Expand Down Expand Up @@ -263,6 +287,30 @@ func (CountTokensResponse) fromProto(p *pb.CountTokensResponse) *CountTokensResp
}
}

// EmbedContentResponse is the response to an `EmbedContentRequest`.
type EmbedContentResponse struct {
	// Embedding is the embedding generated from the input content.
	// Output only.
	Embedding *ContentEmbedding
}

// toProto converts v to its protobuf form. A nil receiver yields nil.
func (v *EmbedContentResponse) toProto() *pb.EmbedContentResponse {
	if v == nil {
		return nil
	}
	out := new(pb.EmbedContentResponse)
	out.Embedding = v.Embedding.toProto()
	return out
}

// fromProto builds an EmbedContentResponse from its protobuf form.
// A nil p yields nil.
func (EmbedContentResponse) fromProto(p *pb.EmbedContentResponse) *EmbedContentResponse {
	if p == nil {
		return nil
	}
	out := new(EmbedContentResponse)
	out.Embedding = (ContentEmbedding{}).fromProto(p.Embedding)
	return out
}

// FinishReason defines the reason why the model stopped generating tokens.
type FinishReason int32

Expand Down Expand Up @@ -763,6 +811,40 @@ func (Schema) fromProto(p *pb.Schema) *Schema {
}
}

// TaskType is the type of task for which an embedding will be used.
type TaskType int32

const (
	// TaskTypeUnspecified is the unset value, which will default to one of the other enum values.
	TaskTypeUnspecified TaskType = 0
	// TaskTypeRetrievalQuery specifies that the given text is a query in a search/retrieval setting.
	TaskTypeRetrievalQuery TaskType = 1
	// TaskTypeRetrievalDocument specifies that the given text is a document from the corpus being searched.
	TaskTypeRetrievalDocument TaskType = 2
	// TaskTypeSemanticSimilarity specifies that the given text will be used for STS.
	TaskTypeSemanticSimilarity TaskType = 3
	// TaskTypeClassification specifies that the given text will be classified.
	TaskTypeClassification TaskType = 4
	// TaskTypeClustering specifies that the embeddings will be used for clustering.
	TaskTypeClustering TaskType = 5
)

// namesForTaskType maps each known TaskType value to the name of its Go constant.
var namesForTaskType = map[TaskType]string{
	TaskTypeUnspecified:        "TaskTypeUnspecified",
	TaskTypeRetrievalQuery:     "TaskTypeRetrievalQuery",
	TaskTypeRetrievalDocument:  "TaskTypeRetrievalDocument",
	TaskTypeSemanticSimilarity: "TaskTypeSemanticSimilarity",
	TaskTypeClassification:     "TaskTypeClassification",
	TaskTypeClustering:         "TaskTypeClustering",
}

// String returns the constant name for known values, and "TaskType(N)" for
// any other value N.
func (v TaskType) String() string {
	name, known := namesForTaskType[v]
	if !known {
		return fmt.Sprintf("TaskType(%d)", v)
	}
	return name
}

// Tool details that the model may use to generate response.
//
// A `Tool` is a piece of code that enables the system to interact with
Expand Down
6 changes: 5 additions & 1 deletion internal/cmd/protoveneer/protoveneer.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,11 @@ func identName(x any) string {
// snakeToCamelCase converts a snake_case identifier to CamelCase.
//
// s is split on underscores; each non-empty segment has its first character
// upper-cased and the rest lower-cased, and the segments are joined with no
// separator. Empty segments (from leading, trailing or consecutive
// underscores, or an empty s) contribute nothing to the result.
func snakeToCamelCase(s string) string {
	words := strings.Split(s, "_")
	for i, w := range words {
		// Guard before indexing: w[0] on an empty segment would panic.
		if len(w) == 0 {
			continue
		}
		words[i] = fmt.Sprintf("%c%s", unicode.ToUpper(rune(w[0])), strings.ToLower(w[1:]))
	}
	return strings.Join(words, "")
}
Expand Down