Skip to content

Commit

Permalink
vertexai: add NewUserContent helper function (#10570)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliben authored Jul 23, 2024
1 parent ba82942 commit 5c8b73c
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 19 deletions.
6 changes: 3 additions & 3 deletions vertexai/genai/caching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ func testCaching(t *testing.T, client *Client) {
argcc := &CachedContent{
Model: model,
Expiration: ExpireTimeOrTTL{TTL: ttl},
Contents: []*Content{{Role: "user", Parts: []Part{
FileData{MIMEType: "text/plain", FileURI: gcsFilePath},
}}},
Contents: []*Content{NewUserContent(FileData{
MIMEType: "text/plain",
FileURI: gcsFilePath})},
}
cc := must(client.CreateCachedContent(ctx, argcc))
compare(cc, wantExpireTime)
Expand Down
4 changes: 2 additions & 2 deletions vertexai/genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *GenerativeModel) StartChat() *ChatSession {
// SendMessage sends a request to the model as part of a chat session.
func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
// Call the underlying client with the entire history plus the argument Content.
cs.History = append(cs.History, newUserContent(parts))
cs.History = append(cs.History, NewUserContent(parts...))
req := cs.m.newGenerateContentRequest(cs.History...)
cc := int32(1)
req.GenerationConfig.CandidateCount = &cc
Expand All @@ -46,7 +46,7 @@ func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*Generat

// SendMessageStream is like SendMessage, but with a streaming request.
func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
cs.History = append(cs.History, newUserContent(parts))
cs.History = append(cs.History, NewUserContent(parts...))
req := cs.m.newGenerateContentRequest(cs.History...)
var cc int32 = 1
req.GenerationConfig.CandidateCount = &cc
Expand Down
10 changes: 3 additions & 7 deletions vertexai/genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,12 @@ func (m *GenerativeModel) Name() string {

// GenerateContent produces a single request and response.
func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
return m.generateContent(ctx, m.newGenerateContentRequest(newUserContent(parts)))
return m.generateContent(ctx, m.newGenerateContentRequest(NewUserContent(parts...)))
}

// GenerateContentStream returns an iterator that enumerates responses.
func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
streamClient, err := m.c.pc.StreamGenerateContent(ctx, m.newGenerateContentRequest(newUserContent(parts)))
streamClient, err := m.c.pc.StreamGenerateContent(ctx, m.newGenerateContentRequest(NewUserContent(parts...)))
return &GenerateContentResponseIterator{
sc: streamClient,
err: err,
Expand Down Expand Up @@ -221,10 +221,6 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) *pb.Ge
}
}

func newUserContent(parts []Part) *Content {
return &Content{Role: roleUser, Parts: parts}
}

// GenerateContentResponseIterator is an iterator over GnerateContentResponse.
type GenerateContentResponseIterator struct {
sc pb.PredictionService_StreamGenerateContentClient
Expand Down Expand Up @@ -286,7 +282,7 @@ func protoToResponse(resp *pb.GenerateContentResponse) (*GenerateContentResponse

// CountTokens counts the number of tokens in the content.
func (m *GenerativeModel) CountTokens(ctx context.Context, parts ...Part) (*CountTokensResponse, error) {
req := m.newCountTokensRequest(newUserContent(parts))
req := m.newCountTokensRequest(NewUserContent(parts...))
res, err := m.c.pc.CountTokens(ctx, req)
if err != nil {
return nil, err
Expand Down
4 changes: 1 addition & 3 deletions vertexai/genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ func TestLive(t *testing.T) {
t.Run("system-instructions", func(t *testing.T) {
model := client.GenerativeModel(defaultModel)
model.Temperature = Ptr[float32](0)
model.SystemInstruction = &Content{
Parts: []Part{Text("You are Yoda from Star Wars.")},
}
model.SystemInstruction = NewUserContent(Text("You are Yoda from Star Wars."))
resp, err := model.GenerateContent(ctx, Text("What is the average size of a swallow?"))
if err != nil {
t.Fatal(err)
Expand Down
10 changes: 10 additions & 0 deletions vertexai/genai/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,13 @@ func (c *Candidate) FunctionCalls() []FunctionCall {
}
return fcs
}

// NewUserContent returns a [Content] with a "user" role set and one or more
// parts.
func NewUserContent(parts ...Part) *Content {
content := &Content{Role: roleUser, Parts: []Part{}}
for _, part := range parts {
content.Parts = append(content.Parts, part)
}
return content
}
6 changes: 2 additions & 4 deletions vertexai/genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ func ExampleGenerativeModel_GenerateContent_config() {
model.SetTopP(0.5)
model.SetTopK(20)
model.SetMaxOutputTokens(100)
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars."))
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -315,7 +313,7 @@ func ExampleClient_cachedContent() {
file := genai.FileData{MIMEType: "application/pdf", FileURI: "gs://my-bucket/my-doc.pdf"}
cc, err := client.CreateCachedContent(ctx, &genai.CachedContent{
Model: modelName,
Contents: []*genai.Content{{Parts: []genai.Part{file}}},
Contents: []*genai.Content{genai.NewUserContent(file)},
})
model := client.GenerativeModelFromCachedContent(cc)
// Work with the model as usual in this program.
Expand Down

0 comments on commit 5c8b73c

Please sign in to comment.