From 55636238cec0f0cd4be531229964abca6ba0a71e Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 13 Dec 2023 07:52:56 -0500 Subject: [PATCH] add Client.ListModels --- genai/client.go | 10 ++- genai/client_test.go | 34 +++++++- genai/config.yaml | 10 +++ genai/generativelanguagepb_veneer.gen.go | 99 ++++++++++++++++++++++++ genai/list_models.go | 50 ++++++++++++ 5 files changed, 198 insertions(+), 5 deletions(-) create mode 100644 genai/list_models.go diff --git a/genai/client.go b/genai/client.go index ec5ee08..d849517 100644 --- a/genai/client.go +++ b/genai/client.go @@ -33,7 +33,8 @@ import ( // A Client is a Google generative AI client. type Client struct { - c *gl.GenerativeClient + c *gl.GenerativeClient + mc *gl.ModelClient } // NewClient creates a new Google generative AI client. @@ -48,7 +49,12 @@ func NewClient(ctx context.Context, opts ...option.ClientOption) (*Client, error if err != nil { return nil, err } - return &Client{c: c}, nil + mc, err := gl.NewModelRESTClient(ctx, opts...) + if err != nil { + c.Close() // best effort: release the generative client; the creation error is the one to report + return nil, err + } + return &Client{c: c, mc: mc}, nil } // Close closes the client. 
diff --git a/genai/client_test.go b/genai/client_test.go index b9decfc..1808103 100644 --- a/genai/client_test.go +++ b/genai/client_test.go @@ -64,7 +64,7 @@ func TestLive(t *testing.T) { t.Run("streaming", func(t *testing.T) { iter := model.GenerateContentStream(ctx, Text("Are you hungry?")) got := responsesString(t, iter) - checkMatch(t, got, `(don't|do\s+not) (have|possess) .*(a .* body|the ability)`) + checkMatch(t, got, `(don't|do\s+not) (have|possess) .*(a .* needs|body|the ability)`) }) t.Run("chat", func(t *testing.T) { @@ -231,7 +231,7 @@ func TestLive(t *testing.T) { if g, w := funcall.Name, weatherTool.FunctionDeclarations[0].Name; g != w { t.Errorf("FunctionCall.Name: got %q, want %q", g, w) } - if g, c := funcall.Args["location"], "New York"; !strings.Contains(g.(string), c) { + if g, c := funcall.Args["location"], "New York"; g != nil && !strings.Contains(g.(string), c) { t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, g, c) } res, err = session.SendMessage(ctx, FunctionResponse{ @@ -243,7 +243,7 @@ func TestLive(t *testing.T) { if err != nil { t.Fatal(err) } - checkMatch(t, responseString(res), "(it's}|weather) .*cold") + checkMatch(t, responseString(res), "(it's|it is|weather) .*cold") }) t.Run("embed", func(t *testing.T) { em := client.EmbeddingModel("embedding-001") @@ -263,6 +263,34 @@ func TestLive(t *testing.T) { t.Errorf("bad result: %v\n", res) } }) + t.Run("list-models", func(t *testing.T) { + iter := client.ListModels(ctx) + var got []*Model + for { + m, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatal(err) + } + got = append(got, m) + } + + for _, name := range []string{"gemini-pro", "embedding-001"} { + has := false + fullName := "models/" + name + for _, m := range got { + if m.Name == fullName { + has = true + break + } + } + if !has { + t.Errorf("missing model %q", name) + } + } + }) } func TestJoinResponses(t *testing.T) { diff --git a/genai/config.yaml 
b/genai/config.yaml index 93375a0..0139940 100644 --- a/genai/config.yaml +++ b/genai/config.yaml @@ -120,6 +120,16 @@ types: ContentEmbedding: + Model: + fields: + BaseModelId: + name: BaseModelID + Temperature: + type: float32 + TopP: + type: float32 + TopK: + type: int32 diff --git a/genai/generativelanguagepb_veneer.gen.go b/genai/generativelanguagepb_veneer.gen.go index f21a362..b71526a 100644 --- a/genai/generativelanguagepb_veneer.gen.go +++ b/genai/generativelanguagepb_veneer.gen.go @@ -650,6 +650,105 @@ func (v HarmProbability) String() string { return fmt.Sprintf("HarmProbability(%d)", v) } +// Model is information about a Generative Language Model. +type Model struct { + // Required. The resource name of the `Model`. + // + // Format: `models/{model}` with a `{model}` naming convention of: + // + // * "{base_model_id}-{version}" + // + // Examples: + // + // * `models/chat-bison-001` + Name string + // Required. The name of the base model, pass this to the generation request. + // + // Examples: + // + // * `chat-bison` + BaseModelID string + // Required. The version number of the model. + // + // This represents the major version + Version string + // The human-readable name of the model. E.g. "Chat Bison". + // + // The name can be up to 128 characters long and can consist of any UTF-8 + // characters. + DisplayName string + // A short description of the model. + Description string + // Maximum number of input tokens allowed for this model. + InputTokenLimit int32 + // Maximum number of output tokens available for this model. + OutputTokenLimit int32 + // The model's supported generation methods. + // + // The method names are defined as Pascal case + // strings, such as `generateMessage` which correspond to API methods. + SupportedGenerationMethods []string + // Controls the randomness of the output. + // + // Values can range over `[0.0,1.0]`, inclusive. 
A value closer to `1.0` will + // produce responses that are more varied, while a value closer to `0.0` will + // typically result in less surprising responses from the model. + // This value specifies default to be used by the backend while making the + // call to the model. + Temperature float32 + // For Nucleus sampling. + // + // Nucleus sampling considers the smallest set of tokens whose probability + // sum is at least `top_p`. + // This value specifies default to be used by the backend while making the + // call to the model. + TopP float32 + // For Top-k sampling. + // + // Top-k sampling considers the set of `top_k` most probable tokens. + // This value specifies default to be used by the backend while making the + // call to the model. + TopK int32 +} + +func (v *Model) toProto() *pb.Model { + if v == nil { + return nil + } + return &pb.Model{ + Name: v.Name, + BaseModelId: v.BaseModelID, + Version: v.Version, + DisplayName: v.DisplayName, + Description: v.Description, + InputTokenLimit: v.InputTokenLimit, + OutputTokenLimit: v.OutputTokenLimit, + SupportedGenerationMethods: v.SupportedGenerationMethods, + Temperature: support.AddrOrNil(v.Temperature), + TopP: support.AddrOrNil(v.TopP), + TopK: support.AddrOrNil(v.TopK), + } +} + +func (Model) fromProto(p *pb.Model) *Model { + if p == nil { + return nil + } + return &Model{ + Name: p.Name, + BaseModelID: p.BaseModelId, + Version: p.Version, + DisplayName: p.DisplayName, + Description: p.Description, + InputTokenLimit: p.InputTokenLimit, + OutputTokenLimit: p.OutputTokenLimit, + SupportedGenerationMethods: p.SupportedGenerationMethods, + Temperature: support.DerefOrZero(p.Temperature), + TopP: support.DerefOrZero(p.TopP), + TopK: support.DerefOrZero(p.TopK), + } +} + // PromptFeedback contains a set of the feedback metadata the prompt specified in // `GenerateContentRequest.content`. 
type PromptFeedback struct { diff --git a/genai/list_models.go b/genai/list_models.go new file mode 100644 index 0000000..e9dfa4c --- /dev/null +++ b/genai/list_models.go @@ -0,0 +1,51 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package genai + +import ( + "context" + + gl "cloud.google.com/go/ai/generativelanguage/apiv1beta" + pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" + + "google.golang.org/api/iterator" +) + +// ListModels returns an iterator over the results of the model client's ListModels call, made with an empty request. +func (c *Client) ListModels(ctx context.Context) *ModelIterator { + return &ModelIterator{ + it: c.mc.ListModels(ctx, &pb.ListModelsRequest{}), + } +} + +// A ModelIterator iterates over Models. +type ModelIterator struct { + it *gl.ModelIterator +} + +// Next returns the next result. Its second return value is iterator.Done if there are no more +// results. Once Next returns Done, all subsequent calls will return Done. +func (it *ModelIterator) Next() (*Model, error) { + m, err := it.it.Next() + if err != nil { + return nil, err + } + return (Model{}).fromProto(m), nil +} + +// PageInfo supports pagination. See the google.golang.org/api/iterator package for details. +func (it *ModelIterator) PageInfo() *iterator.PageInfo { + return it.it.PageInfo() +}