Skip to content

Commit

Permalink
genai: factor out uploadFile function for examples (#173)
Browse files Browse the repository at this point in the history
We're going to have several more examples that upload files; factor out
a function that does this properly, checking for status etc. It's like
the function `client_test` uses, but without the `testing.T` scaffold.
  • Loading branch information
eliben authored Jul 13, 2024
1 parent b48f58a commit 608d329
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 16 deletions.
45 changes: 37 additions & 8 deletions genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"os"
"path/filepath"
"strings"
"time"

"github.com/google/generative-ai-go/genai"
"github.com/google/generative-ai-go/genai/internal/testhelpers"
Expand All @@ -37,6 +38,38 @@ import (

var testDataDir = filepath.Join(testhelpers.ModuleRootDir(), "genai", "testdata")

// uploadFile uploads the five file to the service, and returns its handle if
// successful.
// To clean up the file, defer a client.DeleteFile(ctx, file.Name)
// call when a file is successfully returned. file.Name will be a uniqely
// generated string to identify the file on the service.
func uploadFile(ctx context.Context, client *genai.Client, filepath string) (*genai.File, error) {
osf, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer osf.Close()

file, err := client.UploadFile(ctx, "", osf, nil)
if err != nil {
return nil, err
}

for file.State == genai.FileStateProcessing {
log.Printf("processing %s", file.Name)
time.Sleep(5 * time.Second)
var err error
file, err = client.GetFile(ctx, file.Name)
if err != nil {
return nil, err
}
}
if file.State != genai.FileStateActive {
return nil, fmt.Errorf("uploaded file has state %s, not active", file.State)
}
return file, nil
}

func ExampleGenerativeModel_GenerateContent() {
ctx := context.Background()
client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
Expand Down Expand Up @@ -203,6 +236,7 @@ func ExampleGenerativeModel_CountTokens_textOnly() {
fmt.Println("candidates_token_count:", resp.UsageMetadata.CandidatesTokenCount)
fmt.Println("total_token_count:", resp.UsageMetadata.TotalTokenCount)
// ( prompt_token_count: 10, candidates_token_count: 38, total_token_count: 48 )

}

func ExampleGenerativeModel_CountTokens_cachedContent() {
Expand Down Expand Up @@ -293,19 +327,14 @@ func ExampleGenerativeModel_CountTokens_imageUploadFile() {

model := client.GenerativeModel("gemini-1.5-flash")
prompt := "Tell me about this image"
imageFile, err := os.Open(filepath.Join(testDataDir, "personWorkingOnComputer.jpg"))
if err != nil {
log.Fatal(err)
}
defer imageFile.Close()

uploadedFile, err := client.UploadFile(ctx, "", imageFile, nil)
file, err := uploadFile(ctx, client, filepath.Join(testDataDir, "personWorkingOnComputer.jpg"))
if err != nil {
log.Fatal(err)
}
defer client.DeleteFile(ctx, file.Name)

fd := genai.FileData{
URI: uploadedFile.URI,
URI: file.URI,
}
// Call `CountTokens` to get the input token count
// of the combined text and file (`total_tokens`).
Expand Down
44 changes: 36 additions & 8 deletions genai/internal/samples/docs-snippets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"os"
"path/filepath"
"strings"
"time"

"github.com/google/generative-ai-go/genai"
"github.com/google/generative-ai-go/genai/internal/testhelpers"
Expand All @@ -37,6 +38,38 @@ import (

var testDataDir = filepath.Join(testhelpers.ModuleRootDir(), "genai", "testdata")

// uploadFile uploads the given file to the service, and returns a [genai.File]
// representing it.
// To clean up the file, defer a client.DeleteFile(ctx, file.Name)
// call when a file is successfully returned. file.Name will be a uniqely
// generated string to identify the file on the service.
func uploadFile(ctx context.Context, client *genai.Client, filepath string) (*genai.File, error) {
osf, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer osf.Close()

file, err := client.UploadFile(ctx, "", osf, nil)
if err != nil {
return nil, err
}

for file.State == genai.FileStateProcessing {
log.Printf("processing %s", file.Name)
time.Sleep(5 * time.Second)
var err error
file, err = client.GetFile(ctx, file.Name)
if err != nil {
return nil, err
}
}
if file.State != genai.FileStateActive {
return nil, fmt.Errorf("uploaded file has state %s, not active", file.State)
}
return file, nil
}

func ExampleGenerativeModel_GenerateContent() {
ctx := context.Background()
client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
Expand Down Expand Up @@ -299,19 +332,14 @@ func ExampleGenerativeModel_CountTokens_imageUploadFile() {
// [START tokens_multimodal_image_file_api]
model := client.GenerativeModel("gemini-1.5-flash")
prompt := "Tell me about this image"
imageFile, err := os.Open(filepath.Join(testDataDir, "personWorkingOnComputer.jpg"))
if err != nil {
log.Fatal(err)
}
defer imageFile.Close()

uploadedFile, err := client.UploadFile(ctx, "", imageFile, nil)
file, err := uploadFile(ctx, client, filepath.Join(testDataDir, "personWorkingOnComputer.jpg"))
if err != nil {
log.Fatal(err)
}
defer client.DeleteFile(ctx, file.Name)

fd := genai.FileData{
URI: uploadedFile.URI,
URI: file.URI,
}
// Call `CountTokens` to get the input token count
// of the combined text and file (`total_tokens`).
Expand Down

0 comments on commit 608d329

Please sign in to comment.