Skip to content

Commit

Permalink
genai: omit empty text parts from session history
Browse files Browse the repository at this point in the history
When adding to session history, remove empty text parts.

This is the same change as googleapis/google-cloud-go#10362.

It fixes TestLive/tools/direct.
  • Loading branch information
jba committed Nov 18, 2024
1 parent ae7597a commit d1fc99f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
14 changes: 13 additions & 1 deletion genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,20 @@ func (cs *ChatSession) addToHistory(cands []*Candidate) bool {
return false
}
c.Role = roleModel
cs.History = append(cs.History, c)
cs.History = append(cs.History, copySanitizedModelContent(c))
return true
}
return false
}

// copySanitizedModelContent creates a (shallow) copy of c with role set to
// model and empty text parts removed.
func copySanitizedModelContent(c *Content) *Content {
newc := &Content{Role: roleModel}
for _, part := range c.Parts {
if t, ok := part.(Text); !ok || len(string(t)) > 0 {
newc.Parts = append(newc.Parts, part)
}
}
return newc
}
7 changes: 6 additions & 1 deletion genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"testing"
"time"

"github.com/googleapis/gax-go/v2/apierror"
"google.golang.org/api/googleapi"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
Expand Down Expand Up @@ -391,13 +392,17 @@ func TestLive(t *testing.T) {
if c := "Mountain View"; !strings.Contains(locArg, c) {
t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, locArg, c)
}
res, err = session.SendMessage(ctx, FunctionResponse{
res, err = session.SendMessage(ctx, Text("response:"), FunctionResponse{
Name: movieTool.FunctionDeclarations[0].Name,
Response: map[string]any{
"theater": "AMC16",
},
})
if err != nil {
if ae, ok := err.(*apierror.APIError); ok {
t.Fatal(ae.Unwrap())

}
t.Fatal(err)
}
checkMatch(t, responseString(res), "AMC")
Expand Down

0 comments on commit d1fc99f

Please sign in to comment.