Skip to content

Commit

Permalink
genai: add debug printing of requests (#178)
Browse files Browse the repository at this point in the history
Sometimes it is useful to see the requests that are being
sent to the service. Add a facility for this.
  • Loading branch information
jba authored Jul 16, 2024
1 parent e0b57b6 commit f996f0d
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 9 deletions.
12 changes: 8 additions & 4 deletions genai/caching.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ func (c *Client) CreateCachedContent(ctx context.Context, cc *CachedContent) (*C
}
pcc := cc.toProto()
pcc.Model = Ptr(fullModelName(cc.Model))
return c.cachedContentFromProto(c.cc.CreateCachedContent(ctx, &pb.CreateCachedContentRequest{
req := &pb.CreateCachedContentRequest{
CachedContent: pcc,
}))
}
debugPrint(req)
return c.cachedContentFromProto(c.cc.CreateCachedContent(ctx, req))
}

// GetCachedContent retrieves the CachedContent with the given name.
Expand Down Expand Up @@ -108,10 +110,12 @@ func (c *Client) UpdateCachedContent(ctx context.Context, cc *CachedContent, ccu
if ccu.Expiration.ExpireTime.IsZero() {
mask = "ttl"
}
return c.cachedContentFromProto(c.cc.UpdateCachedContent(ctx, &pb.UpdateCachedContentRequest{
req := &pb.UpdateCachedContentRequest{
CachedContent: cc2.toProto(),
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{mask}},
}))
}
debugPrint(req)
return c.cachedContentFromProto(c.cc.UpdateCachedContent(ctx, req))
}

// ListCachedContents lists all the CachedContents associated with the project and location.
Expand Down
11 changes: 8 additions & 3 deletions genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.G
if m.CachedContentName != "" {
cc = &m.CachedContentName
}
return &pb.GenerateContentRequest{
req := &pb.GenerateContentRequest{
Model: m.fullName,
Contents: transformSlice(contents, (*Content).toProto),
SafetySettings: transformSlice(m.SafetySettings, (*SafetySetting).toProto),
Expand All @@ -236,6 +236,8 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.G
SystemInstruction: m.SystemInstruction.toProto(),
CachedContent: cc,
}
debugPrint(req)
return req
})
}

Expand Down Expand Up @@ -327,10 +329,12 @@ func (m *GenerativeModel) newCountTokensRequest(contents ...*Content) (*pb.Count
if err != nil {
return nil, err
}
return &pb.CountTokensRequest{
req := &pb.CountTokensRequest{
Model: m.fullName,
GenerateContentRequest: gcr,
}, nil
}
debugPrint(req)
return req, nil
}

// Info returns information about the model.
Expand All @@ -340,6 +344,7 @@ func (m *GenerativeModel) Info(ctx context.Context) (*ModelInfo, error) {

func (c *Client) modelInfo(ctx context.Context, fullName string) (*ModelInfo, error) {
req := &pb.GetModelRequest{Name: fullName}
debugPrint(req)
res, err := c.mc.GetModel(ctx, req)
if err != nil {
return nil, err
Expand Down
38 changes: 38 additions & 0 deletions genai/debug.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file contains debugging support functions.

package genai

import (
"fmt"
"os"

"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"
)

// printRequests controls whether request protobufs are written to stderr.
var printRequests = false

func debugPrint(m proto.Message) {
if !printRequests {
return
}
fmt.Fprintln(os.Stderr, "--------")
fmt.Fprintf(os.Stderr, "%T\n", m)
fmt.Fprint(os.Stderr, prototext.Format(m))
fmt.Fprintln(os.Stderr, "^^^^^^^^")
}
1 change: 1 addition & 0 deletions genai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func newEmbedContentRequest(model string, tt TaskType, title string, parts []Par
taskType := pb.TaskType(tt)
req.TaskType = &taskType
}
debugPrint(req)
return req
}

Expand Down
8 changes: 6 additions & 2 deletions genai/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ func (c *Client) UploadFile(ctx context.Context, name string, r io.Reader, opts

// GetFile returns the named file.
func (c *Client) GetFile(ctx context.Context, name string) (*File, error) {
pf, err := c.fc.GetFile(ctx, &pb.GetFileRequest{Name: userNameToServiceName(name)})
req := &pb.GetFileRequest{Name: userNameToServiceName(name)}
debugPrint(req)
pf, err := c.fc.GetFile(ctx, req)
if err != nil {
return nil, err
}
Expand All @@ -84,7 +86,9 @@ func (c *Client) GetFile(ctx context.Context, name string) (*File, error) {
// DeleteFile deletes the file with the given name.
// It is an error to delete a file that does not exist.
func (c *Client) DeleteFile(ctx context.Context, name string) error {
return c.fc.DeleteFile(ctx, &pb.DeleteFileRequest{Name: userNameToServiceName(name)})
req := &pb.DeleteFileRequest{Name: userNameToServiceName(name)}
debugPrint(req)
return c.fc.DeleteFile(ctx, req)
}

// userNameToServiceName converts a name supplied by the user to a name required by the service.
Expand Down

0 comments on commit f996f0d

Please sign in to comment.