Skip to content

Commit 214af16

Browse files
authored
feat(genai): add tokenizer package (#10699)
1 parent f5833e6 commit 214af16

File tree

3 files changed

+312
-0
lines changed

3 files changed

+312
-0
lines changed
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package tokenizer_test
15+
16+
import (
17+
"fmt"
18+
"log"
19+
20+
"cloud.google.com/go/vertexai/genai"
21+
"cloud.google.com/go/vertexai/genai/tokenizer"
22+
)
23+
24+
func ExampleTokenizer_CountTokens() {
25+
tok, err := tokenizer.New("gemini-1.5-flash")
26+
if err != nil {
27+
log.Fatal(err)
28+
}
29+
30+
ntoks, err := tok.CountTokens(genai.Text("a prompt"), genai.Text("another prompt"))
31+
if err != nil {
32+
log.Fatal(err)
33+
}
34+
35+
fmt.Println("total token count:", ntoks.TotalTokens)
36+
37+
// Output: total token count: 4
38+
}

vertexai/genai/tokenizer/tokenizer.go

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// Package tokenizer provides local token counting for Gemini models. This
16+
// tokenizer downloads its model from the web, but otherwise doesn't require
17+
// an API call for every CountTokens invocation.
18+
package tokenizer
19+
20+
import (
21+
"bytes"
22+
"crypto/sha256"
23+
"encoding/hex"
24+
"fmt"
25+
"io"
26+
"net/http"
27+
"os"
28+
"path/filepath"
29+
30+
"cloud.google.com/go/vertexai/genai"
31+
"cloud.google.com/go/vertexai/internal/sentencepiece"
32+
)
33+
34+
// supportedModels is the set of model names for which a tokenizer can be
// created. A name is supported only if it is present here and mapped to true.
var supportedModels = map[string]bool{
	// Gemini 1.0 Pro.
	"gemini-1.0-pro":     true,
	"gemini-1.0-pro-001": true,
	"gemini-1.0-pro-002": true,

	// Gemini 1.5 Pro.
	"gemini-1.5-pro":     true,
	"gemini-1.5-pro-001": true,

	// Gemini 1.5 Flash.
	"gemini-1.5-flash":     true,
	"gemini-1.5-flash-001": true,
}
43+
44+
// Tokenizer is a local tokenizer for text.
45+
type Tokenizer struct {
46+
encoder *sentencepiece.Encoder
47+
}
48+
49+
// CountTokensResponse is the result returned by [Tokenizer.CountTokens].
type CountTokensResponse struct {
	// TotalTokens is the token count summed over all counted parts.
	TotalTokens int32
}
53+
54+
// New creates a new [Tokenizer] from a model name; the model name is the same
55+
// as you would pass to a [genai.Client.GenerativeModel].
56+
func New(modelName string) (*Tokenizer, error) {
57+
if !supportedModels[modelName] {
58+
return nil, fmt.Errorf("model %s is not supported", modelName)
59+
}
60+
61+
data, err := loadModelData(gemmaModelURL, gemmaModelHash)
62+
if err != nil {
63+
return nil, fmt.Errorf("loading model: %w", err)
64+
}
65+
66+
encoder, err := sentencepiece.NewEncoder(bytes.NewReader(data))
67+
if err != nil {
68+
return nil, fmt.Errorf("creating encoder: %w", err)
69+
}
70+
71+
return &Tokenizer{encoder: encoder}, nil
72+
}
73+
74+
// CountTokens counts the tokens in all the given parts and returns their
75+
// sum. Only [genai.Text] parts are suppored; an error will be returned if
76+
// non-text parts are provided.
77+
func (tok *Tokenizer) CountTokens(parts ...genai.Part) (*CountTokensResponse, error) {
78+
sum := 0
79+
80+
for _, part := range parts {
81+
if t, ok := part.(genai.Text); ok {
82+
toks := tok.encoder.Encode(string(t))
83+
sum += len(toks)
84+
} else {
85+
return nil, fmt.Errorf("Tokenizer.CountTokens only supports Text parts")
86+
}
87+
}
88+
89+
return &CountTokensResponse{TotalTokens: int32(sum)}, nil
90+
}
91+
92+
const (
	// gemmaModelURL is the URL the tokenizer model file is downloaded from.
	// It is pinned to a specific commit of the gemma_pytorch repository.
	gemmaModelURL = "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model"

	// gemmaModelHash is the hash the model file is expected to have, in the
	// hex format produced by [hashString].
	gemmaModelHash = "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
)
98+
99+
// downloadModelFile fetches url with an HTTP GET and returns the response
// body. It returns an error if the request fails, if the server answers with
// a status other than 200 OK, or if the body cannot be read in full.
func downloadModelFile(url string) ([]byte, error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Without this check the body of an error page (404, 500, ...) would be
	// handed back to the caller as if it were the requested file.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("downloading %s: unexpected HTTP status %s", url, resp.Status)
	}

	return io.ReadAll(resp.Body)
}
109+
110+
// hashString returns the SHA256 digest of data, encoded as a lowercase hex
// string.
func hashString(data []byte) string {
	hasher := sha256.New()
	hasher.Write(data) // hash.Hash.Write never returns an error.
	return hex.EncodeToString(hasher.Sum(nil))
}
115+
116+
// loadModelData loads model data from the given URL, using a local file-system
117+
// cache. wantHash is the hash (as returned by [hashString] expected on the
118+
// loaded data.
119+
//
120+
// Caching logic:
121+
//
122+
// Assuming $TEMP_DIR is the temporary directory used by the OS, this function
123+
// uses the file $TEMP_DIR/vertexai_tokenizer_model/$urlhash as a cache, where
124+
// $urlhash is hashString(url).
125+
//
126+
// If this cache file doesn't exist, or the data it contains doesn't match
127+
// wantHash, downloads data from the URL and writes it into the cache. If the
128+
// URL's data doesn't match the hash, an error is returned.
129+
func loadModelData(url string, wantHash string) ([]byte, error) {
130+
urlhash := hashString([]byte(url))
131+
cacheDir := filepath.Join(os.TempDir(), "vertexai_tokenizer_model")
132+
cachePath := filepath.Join(cacheDir, urlhash)
133+
134+
cacheData, err := os.ReadFile(cachePath)
135+
if err != nil || hashString(cacheData) != wantHash {
136+
cacheData, err = downloadModelFile(url)
137+
if err != nil {
138+
return nil, fmt.Errorf("loading cache and downloading model: %w", err)
139+
}
140+
141+
if hashString(cacheData) != wantHash {
142+
return nil, fmt.Errorf("downloaded model hash mismatch")
143+
}
144+
145+
err = os.MkdirAll(cacheDir, 0770)
146+
if err != nil {
147+
return nil, fmt.Errorf("creating cache dir: %w", err)
148+
}
149+
err = os.WriteFile(cachePath, cacheData, 0660)
150+
if err != nil {
151+
return nil, fmt.Errorf("writing cache file: %w", err)
152+
}
153+
}
154+
155+
return cacheData, nil
156+
}
+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package tokenizer
15+
16+
import (
17+
"fmt"
18+
"os"
19+
"path/filepath"
20+
"testing"
21+
22+
"cloud.google.com/go/vertexai/genai"
23+
)
24+
25+
func TestDownload(t *testing.T) {
26+
b, err := downloadModelFile(gemmaModelURL)
27+
if err != nil {
28+
t.Fatal(err)
29+
}
30+
31+
if hashString(b) != gemmaModelHash {
32+
t.Errorf("gemma model hash doesn't match")
33+
}
34+
}
35+
36+
func TestLoadModelData(t *testing.T) {
37+
// Tests that loadModelData manages to load the model properly, and download
38+
// a new one as needed.
39+
checkDataAndErr := func(data []byte, err error) {
40+
t.Helper()
41+
if err != nil {
42+
t.Error(err)
43+
}
44+
gotHash := hashString(data)
45+
if gotHash != gemmaModelHash {
46+
t.Errorf("got hash=%v, want=%v", gotHash, gemmaModelHash)
47+
}
48+
}
49+
50+
data, err := loadModelData(gemmaModelURL, gemmaModelHash)
51+
checkDataAndErr(data, err)
52+
53+
// The cache should exist now and have the right data, try again.
54+
data, err = loadModelData(gemmaModelURL, gemmaModelHash)
55+
checkDataAndErr(data, err)
56+
57+
// Overwrite cache file with wrong data, and try again.
58+
cacheDir := filepath.Join(os.TempDir(), "vertexai_tokenizer_model")
59+
cachePath := filepath.Join(cacheDir, hashString([]byte(gemmaModelURL)))
60+
_ = os.MkdirAll(cacheDir, 0770)
61+
_ = os.WriteFile(cachePath, []byte{0, 1, 2, 3}, 0660)
62+
data, err = loadModelData(gemmaModelURL, gemmaModelHash)
63+
checkDataAndErr(data, err)
64+
}
65+
66+
func TestCreateTokenizer(t *testing.T) {
67+
// Create a tokenizer successfully
68+
_, err := New("gemini-1.5-flash")
69+
if err != nil {
70+
t.Error(err)
71+
}
72+
73+
// Create a tokenizer with an unsupported model
74+
_, err = New("gemini-0.92")
75+
if err == nil {
76+
t.Errorf("got no error, want error")
77+
}
78+
}
79+
80+
func TestCountTokens(t *testing.T) {
81+
var tests = []struct {
82+
parts []genai.Part
83+
wantCount int32
84+
}{
85+
{[]genai.Part{genai.Text("hello world")}, 2},
86+
{[]genai.Part{genai.Text("<table><th></th></table>")}, 4},
87+
{[]genai.Part{genai.Text("hello world"), genai.Text("<table><th></th></table>")}, 6},
88+
}
89+
90+
tok, err := New("gemini-1.5-flash")
91+
if err != nil {
92+
t.Error(err)
93+
}
94+
95+
for i, tt := range tests {
96+
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
97+
got, err := tok.CountTokens(tt.parts...)
98+
if err != nil {
99+
t.Error(err)
100+
}
101+
if got.TotalTokens != tt.wantCount {
102+
t.Errorf("got %v, want %v", got.TotalTokens, tt.wantCount)
103+
}
104+
})
105+
}
106+
}
107+
108+
func TestCountTokensNonText(t *testing.T) {
109+
tok, err := New("gemini-1.5-flash")
110+
if err != nil {
111+
t.Error(err)
112+
}
113+
114+
_, err = tok.CountTokens(genai.Text("foo"), genai.ImageData("format", []byte{0, 1}))
115+
if err == nil {
116+
t.Error("got no error, want error")
117+
}
118+
}

0 commit comments

Comments
 (0)