// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package tokenizer provides local token counting for Gemini models. This
// tokenizer downloads its model file from the web (caching it locally), but
// otherwise doesn't require an API call for each CountTokens invocation.
package tokenizer

import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"

	"cloud.google.com/go/vertexai/genai"
	"cloud.google.com/go/vertexai/internal/sentencepiece"
)

var supportedModels = map[string]bool{
	"gemini-1.0-pro":       true,
	"gemini-1.5-pro":       true,
	"gemini-1.5-flash":     true,
	"gemini-1.0-pro-001":   true,
	"gemini-1.0-pro-002":   true,
	"gemini-1.5-pro-001":   true,
	"gemini-1.5-flash-001": true,
}

// Tokenizer is a local tokenizer for text.
type Tokenizer struct {
	encoder *sentencepiece.Encoder
}

// CountTokensResponse is the response of [Tokenizer.CountTokens].
type CountTokensResponse struct {
	TotalTokens int32
}

// New creates a new [Tokenizer] from a model name; the model name is the
// same one you would pass to [genai.Client.GenerativeModel].
func New(modelName string) (*Tokenizer, error) {
	if !supportedModels[modelName] {
		return nil, fmt.Errorf("model %s is not supported", modelName)
	}

	data, err := loadModelData(gemmaModelURL, gemmaModelHash)
	if err != nil {
		return nil, fmt.Errorf("loading model: %w", err)
	}

	encoder, err := sentencepiece.NewEncoder(bytes.NewReader(data))
	if err != nil {
		return nil, fmt.Errorf("creating encoder: %w", err)
	}

	return &Tokenizer{encoder: encoder}, nil
}

// CountTokens counts the tokens in all the given parts and returns their
// sum. Only [genai.Text] parts are supported; an error is returned if
// non-text parts are provided.
func (tok *Tokenizer) CountTokens(parts ...genai.Part) (*CountTokensResponse, error) {
	sum := 0

	for _, part := range parts {
		if t, ok := part.(genai.Text); ok {
			toks := tok.encoder.Encode(string(t))
			sum += len(toks)
		} else {
			return nil, fmt.Errorf("Tokenizer.CountTokens only supports Text parts")
		}
	}

	return &CountTokensResponse{TotalTokens: int32(sum)}, nil
}
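
// Example usage (a sketch, not part of the original change; it assumes this
// package is imported as cloud.google.com/go/vertexai/tokenizer and that the
// model name below is in supportedModels):
//
//	tok, err := tokenizer.New("gemini-1.5-flash")
//	if err != nil {
//		log.Fatal(err)
//	}
//	resp, err := tok.CountTokens(genai.Text("The quick brown fox"))
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println("total tokens:", resp.TotalTokens)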

// gemmaModelURL is the URL from which we download the model file.
const gemmaModelURL = "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model"

// gemmaModelHash is the expected hash of the model file (as calculated
// by [hashString]).
const gemmaModelHash = "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"

// downloadModelFile downloads a file from the given URL.
func downloadModelFile(url string) ([]byte, error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Fail early on a non-200 response; otherwise an error page's body
	// would only surface later as a confusing hash mismatch.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("downloading %s: status %s", url, resp.Status)
	}

	return io.ReadAll(resp.Body)
}

// hashString computes a hex string of the SHA256 hash of data.
func hashString(data []byte) string {
	hash256 := sha256.Sum256(data)
	return hex.EncodeToString(hash256[:])
}
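
// As an illustration (a sketch, not part of the original file), the constant
// gemmaModelHash above is simply hashString applied to the downloaded model
// bytes:
//
//	data, err := downloadModelFile(gemmaModelURL)
//	if err == nil {
//		fmt.Println(hashString(data) == gemmaModelHash) // prints: true
//	}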

// loadModelData loads model data from the given URL, using a local
// file-system cache. wantHash is the hash (as returned by [hashString])
// expected on the loaded data.
//
// Caching logic:
//
// Assuming $TEMP_DIR is the temporary directory used by the OS, this function
// uses the file $TEMP_DIR/vertexai_tokenizer_model/$urlhash as a cache, where
// $urlhash is hashString(url).
//
// If this cache file doesn't exist, or the data it contains doesn't match
// wantHash, this function downloads the data from the URL and writes it into
// the cache. If the downloaded data doesn't match wantHash either, an error
// is returned.
func loadModelData(url string, wantHash string) ([]byte, error) {
	urlhash := hashString([]byte(url))
	cacheDir := filepath.Join(os.TempDir(), "vertexai_tokenizer_model")
	cachePath := filepath.Join(cacheDir, urlhash)

	cacheData, err := os.ReadFile(cachePath)
	if err != nil || hashString(cacheData) != wantHash {
		cacheData, err = downloadModelFile(url)
		if err != nil {
			return nil, fmt.Errorf("loading cache and downloading model: %w", err)
		}

		if hashString(cacheData) != wantHash {
			return nil, fmt.Errorf("downloaded model hash mismatch")
		}

		err = os.MkdirAll(cacheDir, 0770)
		if err != nil {
			return nil, fmt.Errorf("creating cache dir: %w", err)
		}
		err = os.WriteFile(cachePath, cacheData, 0660)
		if err != nil {
			return nil, fmt.Errorf("writing cache file: %w", err)
		}
	}

	return cacheData, nil
}
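
// The caching behavior above can be exercised end-to-end against a local test
// server. The following is a sketch (assumed to live in a _test.go file in
// this package, since loadModelData and hashString are unexported, with
// imports bytes, net/http, net/http/httptest, and testing). The second call
// must succeed even after the server is closed, proving the data came from
// the file-system cache:
//
//	func TestLoadModelDataCaching(t *testing.T) {
//		data := []byte("fake model bytes")
//		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//			w.Write(data)
//		}))
//		want := hashString(data)
//
//		// First call: downloads from the server and populates the cache.
//		got, err := loadModelData(srv.URL, want)
//		if err != nil {
//			t.Fatal(err)
//		}
//		if !bytes.Equal(got, data) {
//			t.Errorf("got %q, want %q", got, data)
//		}
//
//		// Second call: the server is down, so data must come from the cache.
//		srv.Close()
//		got, err = loadModelData(srv.URL, want)
//		if err != nil {
//			t.Fatalf("expected cache hit after server shutdown: %v", err)
//		}
//		if !bytes.Equal(got, data) {
//			t.Errorf("cache returned %q, want %q", got, data)
//		}
//	}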