@@ -20,7 +20,7 @@ public sealed class Tiktoken : Model
2020 {
2121 private readonly Dictionary < ReadOnlyMemory < byte > , int > _encoder = null ! ;
2222 private readonly IReadOnlyDictionary < int , byte [ ] > _decoder = null ! ;
23- private readonly LruCache < string , int [ ] > _cache ;
23+ private readonly LruCache < string , int [ ] > ? _cache ;
2424 private readonly IReadOnlyDictionary < string , int > ? _specialTokensEncoder ;
2525 private readonly Dictionary < int , string > ? _specialTokensDecoder ;
2626 private readonly Dictionary < string , int > _vocab = null ! ;
@@ -96,7 +96,14 @@ private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>?
9696
/// <summary>
/// Shared private constructor that sets up the optional encoding cache.
/// </summary>
/// <param name="cacheSize">
/// Capacity passed to the <see cref="LruCache{TKey, TValue}"/> constructor. A value of zero
/// leaves <c>_cache</c> as <see langword="null"/>, so no cache is allocated.
/// </param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="cacheSize"/> is negative.</exception>
private Tiktoken(int cacheSize)
{
    // Reject invalid sizes up front so the rest of the constructor only deals with 0 or positive values.
    if (cacheSize < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(cacheSize));
    }

    // The field is nullable: a size of 0 means "no cache", and call sites go through `_cache?.`.
    _cache = cacheSize > 0
        ? new LruCache<string, int[]>(cacheSize)
        : null;
}
101108
102109 /// <summary>
@@ -198,7 +205,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
198205 throw new InvalidOperationException ( $ "The special token { sequence } doesn't exist in the tokenizer") ;
199206 }
200207
201- if ( _cache . Lookup ( sequence , out int [ ] ids ) )
208+ if ( _cache ? . Lookup ( sequence , out int [ ] ids ) is true )
202209 {
203210 tokens = new Token [ ids . Length ] ;
204211 tokens [ 0 ] = new Token ( ids [ 0 ] , sequence , ( 0 , sequence . Length ) ) ;
@@ -222,7 +229,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
222229
223230 int [ ] encodedIds = BytePairEncoder . BytePairEncode ( arrayPoolArray . AsMemory ( 0 , encodedLength ) , _encoder ) ;
224231 Debug . Assert ( encodedIds . Length > 0 ) ;
225- _cache . Add ( sequence , encodedIds ) ;
232+ _cache ? . Add ( sequence , encodedIds ) ;
226233
227234 tokens = new Token [ encodedIds . Length ] ;
228235 tokens [ 0 ] = new Token ( encodedIds [ 0 ] , sequence , ( 0 , sequence . Length ) ) ;
@@ -259,7 +266,7 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<i
259266 return ;
260267 }
261268
262- if ( _cache . Lookup ( sequence , out int [ ] tokenIds ) )
269+ if ( _cache ? . Lookup ( sequence , out int [ ] tokenIds ) is true )
263270 {
264271 accumulatedIds . AddRange ( tokenIds ) ;
265272 return ;
@@ -275,7 +282,7 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<i
275282 int encodedLength = GetUtf8Bytes ( sequence . AsSpan ( ) , arrayPoolArray ) ;
276283
277284 int [ ] encodedIds = BytePairEncoder . BytePairEncode ( arrayPoolArray . AsMemory ( 0 , encodedLength ) , _encoder ) ;
278- _cache . Add ( sequence , encodedIds ) ;
285+ _cache ? . Add ( sequence , encodedIds ) ;
279286
280287 accumulatedIds . AddRange ( encodedIds ) ;
281288
@@ -301,7 +308,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
301308 return _specialTokensEncoder . TryGetValue ( sequence , out _ ) ? 1 : 0 ;
302309 }
303310
304- if ( _cache . Lookup ( sequence , out int [ ] ids ) )
311+ if ( _cache ? . Lookup ( sequence , out int [ ] ids ) is true )
305312 {
306313 return ids . Length ;
307314 }
@@ -315,7 +322,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
315322 int encodedLength = GetUtf8Bytes ( sequence . AsSpan ( ) , arrayPoolArray ) ;
316323
317324 int [ ] encodedIds = BytePairEncoder . BytePairEncode ( arrayPoolArray . AsMemory ( 0 , encodedLength ) , _encoder ) ;
318- _cache . Add ( sequence , encodedIds ) ;
325+ _cache ? . Add ( sequence , encodedIds ) ;
319326
320327 ArrayPool < byte > . Shared . Return ( arrayPoolArray ) ;
321328 return encodedIds . Length ;
@@ -346,7 +353,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
346353 return specialTokenId ;
347354 }
348355
349- if ( _cache . Lookup ( token , out int [ ] ids ) )
356+ if ( _cache ? . Lookup ( token , out int [ ] ids ) is true )
350357 {
351358 if ( ids . Length == 1 )
352359 {
@@ -367,7 +374,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
367374 int encodedLength = GetUtf8Bytes ( token . AsSpan ( ) , arrayPoolArray ) ;
368375
369376 int [ ] idsToCache = BytePairEncoder . BytePairEncode ( arrayPoolArray . AsMemory ( 0 , encodedLength ) , _encoder ) ;
370- _cache . Add ( token , idsToCache ) ;
377+ _cache ? . Add ( token , idsToCache ) ;
371378
372379 if ( idsToCache . Length == 1 )
373380 {
0 commit comments