Skip to content

Commit 359cfe8

Browse files
authored
genai: support constrained decoding (#82)
Let the user determine how the model should respond with FunctionCalls.
1 parent 78aace2 commit 359cfe8

7 files changed

+277
-46
lines changed

genai/client.go

+2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ type GenerativeModel struct {
105105
GenerationConfig
106106
SafetySettings []*SafetySetting
107107
Tools []*Tool
108+
ToolConfig *ToolConfig
108109
}
109110

110111
// GenerativeModel creates a new instance of the named generative model.
@@ -168,6 +169,7 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) *pb.Ge
168169
Contents: support.TransformSlice(contents, (*Content).toProto),
169170
SafetySettings: support.TransformSlice(m.SafetySettings, (*SafetySetting).toProto),
170171
Tools: support.TransformSlice(m.Tools, (*Tool).toProto),
172+
ToolConfig: m.ToolConfig.toProto(),
171173
GenerationConfig: m.GenerationConfig.toProto(),
172174
}
173175
}

genai/client_test.go

+30-16
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ func TestLive(t *testing.T) {
297297
})
298298
t.Run("tools", func(t *testing.T) {
299299

300-
weatherChat := func(t *testing.T, s *Schema) {
300+
weatherChat := func(t *testing.T, s *Schema, fcm FunctionCallingMode) {
301301
weatherTool := &Tool{
302302
FunctionDeclarations: []*FunctionDeclaration{{
303303
Name: "CurrentWeather",
@@ -308,12 +308,23 @@ func TestLive(t *testing.T) {
308308
model := client.GenerativeModel(*modelName)
309309
model.SetTemperature(0)
310310
model.Tools = []*Tool{weatherTool}
311+
model.ToolConfig = &ToolConfig{
312+
FunctionCallingConfig: &FunctionCallingConfig{
313+
Mode: fcm,
314+
},
315+
}
311316
session := model.StartChat()
312317
res, err := session.SendMessage(ctx, Text("What is the weather like in New York?"))
313318
if err != nil {
314319
t.Fatal(err)
315320
}
316321
funcalls := res.Candidates[0].FunctionCalls()
322+
if fcm == FunctionCallingNone {
323+
if len(funcalls) != 0 {
324+
t.Fatalf("got %d FunctionCalls, want 0", len(funcalls))
325+
}
326+
return
327+
}
317328
if len(funcalls) != 1 {
318329
t.Fatalf("got %d FunctionCalls, want 1", len(funcalls))
319330
}
@@ -339,22 +350,25 @@ func TestLive(t *testing.T) {
339350
}
340351
checkMatch(t, responseString(res), "(it's|it is|weather) .*cold")
341352
}
342-
343-
t.Run("direct", func(t *testing.T) {
344-
weatherChat(t, &Schema{
345-
Type: TypeObject,
346-
Properties: map[string]*Schema{
347-
"location": {
348-
Type: TypeString,
349-
Description: "The city and state, e.g. San Francisco, CA",
350-
},
351-
"unit": {
352-
Type: TypeString,
353-
Enum: []string{"celsius", "fahrenheit"},
354-
},
353+
schema := &Schema{
354+
Type: TypeObject,
355+
Properties: map[string]*Schema{
356+
"location": {
357+
Type: TypeString,
358+
Description: "The city and state, e.g. San Francisco, CA",
355359
},
356-
Required: []string{"location"},
357-
})
360+
"unit": {
361+
Type: TypeString,
362+
Enum: []string{"celsius", "fahrenheit"},
363+
},
364+
},
365+
Required: []string{"location"},
366+
}
367+
t.Run("direct", func(t *testing.T) {
368+
weatherChat(t, schema, FunctionCallingAuto)
369+
})
370+
t.Run("none", func(t *testing.T) {
371+
weatherChat(t, schema, FunctionCallingNone)
358372
})
359373
})
360374
}

genai/config.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ types:
117117

118118
# Types for function calling
119119
Tool:
120+
ToolConfig:
120121
FunctionDeclaration:
121122
FunctionCall:
122123
FunctionResponse:
@@ -126,6 +127,15 @@ types:
126127
protoPrefix: Type_
127128
veneerPrefix: ''
128129

130+
FunctionCallingConfig:
131+
doc: 'holds configuration for function calling.'
132+
133+
FunctionCallingConfig_Mode:
134+
name: FunctionCallingMode
135+
protoPrefix: FunctionCallingConfig
136+
veneerPrefix: FunctionCalling
137+
valueNames:
138+
FunctionCallingConfig_MODE_UNSPECIFIED: FunctionCallingUnspecified
129139

130140

131141

genai/example_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,40 @@ func ExampleTool() {
394394
printResponse(res)
395395
}
396396

397+
func ExampleToolConfig() {
398+
// This example shows how to affect how the model uses the tools provided to it.
399+
// By setting the ToolConfig, you can disable function calling.
400+
401+
// Assume we have created a Model and have set its Tools field with some functions.
402+
// See the Example for Tool for details.
403+
var model *genai.GenerativeModel
404+
405+
// By default, the model will use the functions in its responses if it thinks they are
406+
// relevant, by returning FunctionCall parts.
407+
// Here we set the model's ToolConfig to disable function calling completely.
408+
model.ToolConfig = &genai.ToolConfig{
409+
FunctionCallingConfig: &genai.FunctionCallingConfig{
410+
Mode: genai.FunctionCallingNone,
411+
},
412+
}
413+
414+
// Subsequent calls to ChatSession.SendMessage will not result in FunctionCall responses.
415+
session := model.StartChat()
416+
res, err := session.SendMessage(context.Background(), genai.Text("What is the weather like in New York?"))
417+
if err != nil {
418+
log.Fatal(err)
419+
}
420+
for _, part := range res.Candidates[0].Content.Parts {
421+
if _, ok := part.(genai.FunctionCall); ok {
422+
log.Fatal("did not expect FunctionCall")
423+
}
424+
}
425+
426+
// It is also possible to force a function call by using FunctionCallingAny
427+
// instead of FunctionCallingNone. See the documentation for FunctionCallingMode
428+
// for details.
429+
}
430+
397431
func printResponse(resp *genai.GenerateContentResponse) {
398432
for _, cand := range resp.Candidates {
399433
if cand.Content != nil {

genai/generativelanguagepb_veneer.gen.go

+103-13
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,74 @@ func (FunctionCall) fromProto(p *pb.FunctionCall) *FunctionCall {
400400
}
401401
}
402402

403+
// FunctionCallingConfig holds configuration for function calling.
404+
type FunctionCallingConfig struct {
405+
// Optional. Specifies the mode in which function calling should execute. If
406+
// unspecified, the default value will be set to AUTO.
407+
Mode FunctionCallingMode
408+
// Optional. A set of function names that, when provided, limits the functions
409+
// the model will call.
410+
//
411+
// This should only be set when the Mode is ANY. Function names
412+
// should match [FunctionDeclaration.name]. With mode set to ANY, model will
413+
// predict a function call from the set of function names provided.
414+
AllowedFunctionNames []string
415+
}
416+
417+
func (v *FunctionCallingConfig) toProto() *pb.FunctionCallingConfig {
418+
if v == nil {
419+
return nil
420+
}
421+
return &pb.FunctionCallingConfig{
422+
Mode: pb.FunctionCallingConfig_Mode(v.Mode),
423+
AllowedFunctionNames: v.AllowedFunctionNames,
424+
}
425+
}
426+
427+
func (FunctionCallingConfig) fromProto(p *pb.FunctionCallingConfig) *FunctionCallingConfig {
428+
if p == nil {
429+
return nil
430+
}
431+
return &FunctionCallingConfig{
432+
Mode: FunctionCallingMode(p.Mode),
433+
AllowedFunctionNames: p.AllowedFunctionNames,
434+
}
435+
}
436+
437+
// FunctionCallingMode defines the execution behavior for function calling by defining the
438+
// execution mode.
439+
type FunctionCallingMode int32
440+
441+
const (
442+
// FunctionCallingUnspecified means unspecified function calling mode. This value should not be used.
443+
FunctionCallingUnspecified FunctionCallingMode = 0
444+
// FunctionCallingAuto means default model behavior, model decides to predict either a function call
445+
// or a natural language response.
446+
FunctionCallingAuto FunctionCallingMode = 1
447+
// FunctionCallingAny means model is constrained to always predicting a function call only.
448+
// If "allowed_function_names" are set, the predicted function call will be
449+
// limited to any one of "allowed_function_names", else the predicted
450+
// function call will be any one of the provided "function_declarations".
451+
FunctionCallingAny FunctionCallingMode = 2
452+
// FunctionCallingNone means model will not predict any function call. Model behavior is same as when
453+
// not passing any function declarations.
454+
FunctionCallingNone FunctionCallingMode = 3
455+
)
456+
457+
var namesForFunctionCallingMode = map[FunctionCallingMode]string{
458+
FunctionCallingUnspecified: "FunctionCallingUnspecified",
459+
FunctionCallingAuto: "FunctionCallingAuto",
460+
FunctionCallingAny: "FunctionCallingAny",
461+
FunctionCallingNone: "FunctionCallingNone",
462+
}
463+
464+
func (v FunctionCallingMode) String() string {
465+
if n, ok := namesForFunctionCallingMode[v]; ok {
466+
return n
467+
}
468+
return fmt.Sprintf("FunctionCallingMode(%d)", v)
469+
}
470+
403471
// FunctionDeclaration is structured representation of a function declaration as defined by the
404472
// [OpenAPI 3.03 specification](https://spec.openapis.org/oas/v3.0.3). Included
405473
// in this declaration are the function name and parameters. This
@@ -517,7 +585,7 @@ func (GenerateContentResponse) fromProto(p *pb.GenerateContentResponse) *Generat
517585
type GenerationConfig struct {
518586
// Optional. Number of generated responses to return.
519587
//
520-
// This value must be between [1, 8], inclusive. If unset, this will default
588+
// Currently, this value can only be set to 1. If unset, this will default
521589
// to 1.
522590
CandidateCount *int32
523591
// Optional. The set of character sequences (up to 5) that will stop output
@@ -527,17 +595,15 @@ type GenerationConfig struct {
527595
StopSequences []string
528596
// Optional. The maximum number of tokens to include in a candidate.
529597
//
530-
// If unset, this will default to output_token_limit specified in the `Model`
531-
// specification.
598+
// Note: The default value varies by model, see the `Model.output_token_limit`
599+
// attribute of the `Model` returned from the `getModel` function.
532600
MaxOutputTokens *int32
533601
// Optional. Controls the randomness of the output.
602+
//
534603
// Note: The default value varies by model, see the `Model.temperature`
535-
// attribute of the `Model` returned the `getModel` function.
604+
// attribute of the `Model` returned from the `getModel` function.
536605
//
537-
// Values can range from [0.0,1.0],
538-
// inclusive. A value closer to 1.0 will produce responses that are more
539-
// varied and creative, while a value closer to 0.0 will typically result in
540-
// more straightforward responses from the model.
606+
// Values can range from [0.0, infinity).
541607
Temperature *float32
542608
// Optional. The maximum cumulative probability of tokens to consider when
543609
// sampling.
@@ -550,17 +616,16 @@ type GenerationConfig struct {
550616
// of tokens based on the cumulative probability.
551617
//
552618
// Note: The default value varies by model, see the `Model.top_p`
553-
// attribute of the `Model` returned the `getModel` function.
619+
// attribute of the `Model` returned from the `getModel` function.
554620
TopP *float32
555621
// Optional. The maximum number of tokens to consider when sampling.
556622
//
557623
// The model uses combined Top-k and nucleus sampling.
558624
//
559625
// Top-k sampling considers the set of `top_k` most probable tokens.
560-
// Defaults to 40.
561626
//
562627
// Note: The default value varies by model, see the `Model.top_k`
563-
// attribute of the `Model` returned the `getModel` function.
628+
// attribute of the `Model` returned from the `getModel` function.
564629
TopK *int32
565630
}
566631

@@ -634,9 +699,9 @@ const (
634699
HarmCategoryUnspecified HarmCategory = 0
635700
// HarmCategoryDerogatory means negative or harmful comments targeting identity and/or protected attribute.
636701
HarmCategoryDerogatory HarmCategory = 1
637-
// HarmCategoryToxicity means content that is rude, disrepspectful, or profane.
702+
// HarmCategoryToxicity means content that is rude, disrespectful, or profane.
638703
HarmCategoryToxicity HarmCategory = 2
639-
// HarmCategoryViolence means describes scenarios depictng violence against an individual or group, or
704+
// HarmCategoryViolence means content that describes scenarios depicting violence against an individual or group, or
640705
// general descriptions of gore.
641706
HarmCategoryViolence HarmCategory = 3
642707
// HarmCategorySexual means contains references to sexual acts or other lewd content.
@@ -1044,6 +1109,31 @@ func (Tool) fromProto(p *pb.Tool) *Tool {
10441109
}
10451110
}
10461111

1112+
// ToolConfig is the Tool configuration containing parameters for specifying `Tool` use
1113+
// in the request.
1114+
type ToolConfig struct {
1115+
// Optional. Function calling config.
1116+
FunctionCallingConfig *FunctionCallingConfig
1117+
}
1118+
1119+
func (v *ToolConfig) toProto() *pb.ToolConfig {
1120+
if v == nil {
1121+
return nil
1122+
}
1123+
return &pb.ToolConfig{
1124+
FunctionCallingConfig: v.FunctionCallingConfig.toProto(),
1125+
}
1126+
}
1127+
1128+
func (ToolConfig) fromProto(p *pb.ToolConfig) *ToolConfig {
1129+
if p == nil {
1130+
return nil
1131+
}
1132+
return &ToolConfig{
1133+
FunctionCallingConfig: (FunctionCallingConfig{}).fromProto(p.FunctionCallingConfig),
1134+
}
1135+
}
1136+
10471137
// Type contains the list of OpenAPI data types as defined by
10481138
// https://spec.openapis.org/oas/v3.0.3#data-types
10491139
type Type int32

go.mod

+26-17
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,40 @@ module github.com/google/generative-ai-go
33
go 1.21
44

55
require (
6-
cloud.google.com/go/ai v0.3.0
7-
google.golang.org/api v0.149.0
6+
cloud.google.com/go/ai v0.3.5-0.20240409161017-ce55ad694f21
7+
google.golang.org/api v0.172.0
88
google.golang.org/protobuf v1.33.0
99
)
1010

1111
require (
12-
cloud.google.com/go v0.110.8 // indirect
13-
cloud.google.com/go/compute v1.23.1 // indirect
12+
cloud.google.com/go v0.112.1 // indirect
13+
cloud.google.com/go/compute v1.24.0 // indirect
1414
cloud.google.com/go/compute/metadata v0.2.3 // indirect
15-
cloud.google.com/go/longrunning v0.5.2 // indirect
15+
cloud.google.com/go/longrunning v0.5.6 // indirect
16+
github.com/felixge/httpsnoop v1.0.4 // indirect
17+
github.com/go-logr/logr v1.4.1 // indirect
18+
github.com/go-logr/stdr v1.2.2 // indirect
1619
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
17-
github.com/golang/protobuf v1.5.3 // indirect
20+
github.com/golang/protobuf v1.5.4 // indirect
1821
github.com/google/s2a-go v0.1.7 // indirect
1922
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
20-
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
23+
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
2124
go.opencensus.io v0.24.0 // indirect
22-
golang.org/x/crypto v0.17.0 // indirect
23-
golang.org/x/net v0.17.0 // indirect
24-
golang.org/x/oauth2 v0.13.0 // indirect
25-
golang.org/x/sync v0.4.0 // indirect
26-
golang.org/x/sys v0.15.0 // indirect
25+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
26+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
27+
go.opentelemetry.io/otel v1.24.0 // indirect
28+
go.opentelemetry.io/otel/metric v1.24.0 // indirect
29+
go.opentelemetry.io/otel/trace v1.24.0 // indirect
30+
golang.org/x/crypto v0.21.0 // indirect
31+
golang.org/x/net v0.22.0 // indirect
32+
golang.org/x/oauth2 v0.18.0 // indirect
33+
golang.org/x/sync v0.6.0 // indirect
34+
golang.org/x/sys v0.18.0 // indirect
2735
golang.org/x/text v0.14.0 // indirect
28-
google.golang.org/appengine v1.6.7 // indirect
29-
google.golang.org/genproto v0.0.0-20231016165738-49dd2c1f3d0b // indirect
30-
google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b // indirect
31-
google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect
32-
google.golang.org/grpc v1.59.0 // indirect
36+
golang.org/x/time v0.5.0 // indirect
37+
google.golang.org/appengine v1.6.8 // indirect
38+
google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9 // indirect
39+
google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect
40+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240325203815-454cdb8f5daa // indirect
41+
google.golang.org/grpc v1.62.1 // indirect
3342
)

0 commit comments

Comments
 (0)