Skip to content

Commit 30b5de5

Browse files
authored
genai: omit empty text parts from session history (#226)
When adding to session history, remove empty text parts. This is the same change as googleapis/google-cloud-go#10362. It fixes TestLive/tools/direct.
1 parent ae7597a commit 30b5de5

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

genai/chat.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,20 @@ func (cs *ChatSession) addToHistory(cands []*Candidate) bool {
7070
return false
7171
}
7272
c.Role = roleModel
73-
cs.History = append(cs.History, c)
73+
cs.History = append(cs.History, copySanitizedModelContent(c))
7474
return true
7575
}
7676
return false
7777
}
78+
79+
// copySanitizedModelContent creates a (shallow) copy of c with role set to
80+
// model and empty text parts removed.
81+
func copySanitizedModelContent(c *Content) *Content {
82+
newc := &Content{Role: roleModel}
83+
for _, part := range c.Parts {
84+
if t, ok := part.(Text); !ok || len(string(t)) > 0 {
85+
newc.Parts = append(newc.Parts, part)
86+
}
87+
}
88+
return newc
89+
}

genai/client_test.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"testing"
3131
"time"
3232

33+
"github.com/googleapis/gax-go/v2/apierror"
3334
"google.golang.org/api/googleapi"
3435
"google.golang.org/api/iterator"
3536
"google.golang.org/api/option"
@@ -391,13 +392,17 @@ func TestLive(t *testing.T) {
391392
if c := "Mountain View"; !strings.Contains(locArg, c) {
392393
t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, locArg, c)
393394
}
394-
res, err = session.SendMessage(ctx, FunctionResponse{
395+
res, err = session.SendMessage(ctx, Text("response:"), FunctionResponse{
395396
Name: movieTool.FunctionDeclarations[0].Name,
396397
Response: map[string]any{
397398
"theater": "AMC16",
398399
},
399400
})
400401
if err != nil {
402+
if ae, ok := err.(*apierror.APIError); ok {
403+
t.Fatal(ae.Unwrap())
404+
405+
}
401406
t.Fatal(err)
402407
}
403408
checkMatch(t, responseString(res), "AMC")

0 commit comments

Comments
 (0)