From 5c8b73c63c4b6aad0af486008318cded525578ef Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Tue, 23 Jul 2024 12:35:39 -0600 Subject: [PATCH] vertexai: add NewUserContent helper function (#10570) --- vertexai/genai/caching_test.go | 6 +++--- vertexai/genai/chat.go | 4 ++-- vertexai/genai/client.go | 10 +++------- vertexai/genai/client_test.go | 4 +--- vertexai/genai/content.go | 10 ++++++++++ vertexai/genai/example_test.go | 6 ++---- 6 files changed, 21 insertions(+), 19 deletions(-) diff --git a/vertexai/genai/caching_test.go b/vertexai/genai/caching_test.go index 3738a9889c6b..af18ba5b8a93 100644 --- a/vertexai/genai/caching_test.go +++ b/vertexai/genai/caching_test.go @@ -103,9 +103,9 @@ func testCaching(t *testing.T, client *Client) { argcc := &CachedContent{ Model: model, Expiration: ExpireTimeOrTTL{TTL: ttl}, - Contents: []*Content{{Role: "user", Parts: []Part{ - FileData{MIMEType: "text/plain", FileURI: gcsFilePath}, - }}}, + Contents: []*Content{NewUserContent(FileData{ + MIMEType: "text/plain", + FileURI: gcsFilePath})}, } cc := must(client.CreateCachedContent(ctx, argcc)) compare(cc, wantExpireTime) diff --git a/vertexai/genai/chat.go b/vertexai/genai/chat.go index 377f06b0c967..0797505d6ba7 100644 --- a/vertexai/genai/chat.go +++ b/vertexai/genai/chat.go @@ -32,7 +32,7 @@ func (m *GenerativeModel) StartChat() *ChatSession { // SendMessage sends a request to the model as part of a chat session. func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) { // Call the underlying client with the entire history plus the argument Content. - cs.History = append(cs.History, newUserContent(parts)) + cs.History = append(cs.History, NewUserContent(parts...)) req := cs.m.newGenerateContentRequest(cs.History...) cc := int32(1) req.GenerationConfig.CandidateCount = &cc @@ -46,7 +46,7 @@ func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*Generat // SendMessageStream is like SendMessage, but with a streaming request. func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator { - cs.History = append(cs.History, newUserContent(parts)) + cs.History = append(cs.History, NewUserContent(parts...)) req := cs.m.newGenerateContentRequest(cs.History...) var cc int32 = 1 req.GenerationConfig.CandidateCount = &cc diff --git a/vertexai/genai/client.go b/vertexai/genai/client.go index 432cc505facd..3beda193cb78 100644 --- a/vertexai/genai/client.go +++ b/vertexai/genai/client.go @@ -187,12 +187,12 @@ func (m *GenerativeModel) Name() string { // GenerateContent produces a single request and response. func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) { - return m.generateContent(ctx, m.newGenerateContentRequest(newUserContent(parts))) + return m.generateContent(ctx, m.newGenerateContentRequest(NewUserContent(parts...))) } // GenerateContentStream returns an iterator that enumerates responses. func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator { - streamClient, err := m.c.pc.StreamGenerateContent(ctx, m.newGenerateContentRequest(newUserContent(parts))) + streamClient, err := m.c.pc.StreamGenerateContent(ctx, m.newGenerateContentRequest(NewUserContent(parts...))) return &GenerateContentResponseIterator{ sc: streamClient, err: err, @@ -221,10 +221,6 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) *pb.Ge } } -func newUserContent(parts []Part) *Content { - return &Content{Role: roleUser, Parts: parts} -} - // GenerateContentResponseIterator is an iterator over GnerateContentResponse. type GenerateContentResponseIterator struct { sc pb.PredictionService_StreamGenerateContentClient @@ -286,7 +282,7 @@ func protoToResponse(resp *pb.GenerateContentResponse) (*GenerateContentResponse // CountTokens counts the number of tokens in the content. func (m *GenerativeModel) CountTokens(ctx context.Context, parts ...Part) (*CountTokensResponse, error) { - req := m.newCountTokensRequest(newUserContent(parts)) + req := m.newCountTokensRequest(NewUserContent(parts...)) res, err := m.c.pc.CountTokens(ctx, req) if err != nil { return nil, err diff --git a/vertexai/genai/client_test.go b/vertexai/genai/client_test.go index c590a15c854e..e81c70f782ba 100644 --- a/vertexai/genai/client_test.go +++ b/vertexai/genai/client_test.go @@ -63,9 +63,7 @@ func TestLive(t *testing.T) { t.Run("system-instructions", func(t *testing.T) { model := client.GenerativeModel(defaultModel) model.Temperature = Ptr[float32](0) - model.SystemInstruction = &Content{ - Parts: []Part{Text("You are Yoda from Star Wars.")}, - } + model.SystemInstruction = NewUserContent(Text("You are Yoda from Star Wars.")) resp, err := model.GenerateContent(ctx, Text("What is the average size of a swallow?")) if err != nil { t.Fatal(err) diff --git a/vertexai/genai/content.go b/vertexai/genai/content.go index 0d314b2b55d3..ae01d7bb75f5 100644 --- a/vertexai/genai/content.go +++ b/vertexai/genai/content.go @@ -148,3 +148,13 @@ func (c *Candidate) FunctionCalls() []FunctionCall { } return fcs } + +// NewUserContent returns a [Content] with a "user" role set and one or more +// parts. +func NewUserContent(parts ...Part) *Content { + content := &Content{Role: roleUser, Parts: []Part{}} + for _, part := range parts { + content.Parts = append(content.Parts, part) + } + return content +} diff --git a/vertexai/genai/example_test.go b/vertexai/genai/example_test.go index 97e5da068940..65d151bece7e 100644 --- a/vertexai/genai/example_test.go +++ b/vertexai/genai/example_test.go @@ -74,9 +74,7 @@ func ExampleGenerativeModel_GenerateContent_config() { model.SetTopP(0.5) model.SetTopK(20) model.SetMaxOutputTokens(100) - model.SystemInstruction = &genai.Content{ - Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")}, - } + model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars.")) resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?")) if err != nil { log.Fatal(err) @@ -315,7 +313,7 @@ func ExampleClient_cachedContent() { file := genai.FileData{MIMEType: "application/pdf", FileURI: "gs://my-bucket/my-doc.pdf"} cc, err := client.CreateCachedContent(ctx, &genai.CachedContent{ Model: modelName, - Contents: []*genai.Content{{Parts: []genai.Part{file}}}, + Contents: []*genai.Content{genai.NewUserContent(file)}, }) model := client.GenerativeModelFromCachedContent(cc) // Work with the model as usual in this program.