Skip to content

Commit 158a5e2

Browse files
committed
remove model name dependency from pool
Signed-off-by: Sage Ahrac <[email protected]>
1 parent 0821cff commit 158a5e2

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

pkg/tokenization/pool.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ func (pool *Pool) processTask(task Task) error {
205205
}
206206
}
207207

208-
tokenIDs, overlapRatio := pool.indexer.FindLongestContainedTokens(task.Prompt, task.ModelName)
208+
tokenIDs, overlapRatio := pool.indexer.FindLongestContainedTokens(task.Prompt)
209209

210210
// if the overlap ratio is low, get the full tokenization
211211
if overlapRatio < pool.minPrefixOverlapRatio {
@@ -216,7 +216,7 @@ func (pool *Pool) processTask(task Task) error {
216216
}
217217

218218
// update the indexer with the new tokenization
219-
if e := pool.indexer.AddTokenization(task.ModelName, task.Prompt, tokens, offsets); e != nil {
219+
if e := pool.indexer.AddTokenization(task.Prompt, tokens, offsets); e != nil {
220220
err = fmt.Errorf("tokenization failed for model %s: %w", task.ModelName, e)
221221
return err
222222
}

pkg/tokenization/pool_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ type MockIndexer struct {
6767
mock.Mock
6868
}
6969

70-
func (m *MockIndexer) AddTokenization(modelName, prompt string, tokens []uint32, offsets []tokenizers.Offset) error {
71-
args := m.Called(modelName, prompt, tokens, offsets)
70+
func (m *MockIndexer) AddTokenization(prompt string, tokens []uint32, offsets []tokenizers.Offset) error {
71+
args := m.Called(prompt, tokens, offsets)
7272
return args.Error(0)
7373
}
7474

7575
//nolint:gocritic // unnamedResult: tokens and overlapRatio are self-explanatory from context
76-
func (m *MockIndexer) FindLongestContainedTokens(prompt, modelName string) ([]uint32, float64) {
77-
args := m.Called(prompt, modelName)
76+
func (m *MockIndexer) FindLongestContainedTokens(prompt string) ([]uint32, float64) {
77+
args := m.Called(prompt)
7878
tokens := args.Get(0).([]uint32) //nolint:errcheck // unused mock
7979
return tokens, 0.0
8080
}

0 commit comments

Comments
 (0)