@@ -162,10 +162,10 @@ fn hash_current_thread() -> usize {
162162 // that works great for our use case of avoiding collisions in our array. Unfortunately,
163163 // it's private. However, there are only so many ways you can lay out a u64, so just transmute
164164 // https://github.com/rust-lang/rust/issues/67939
165- const _: [ u8 ; 8 ] = [ 0 ; std:: mem:: size_of :: < std :: thread:: ThreadId > ( ) ] ;
165+ const _: [ u8 ; 8 ] = [ 0 ; std:: mem:: size_of :: < thread:: ThreadId > ( ) ] ;
166166 const _: [ u8 ; 8 ] = [ 0 ; std:: mem:: size_of :: < FakeThreadId > ( ) ] ;
167167 let x = unsafe {
168- std:: mem:: transmute :: < std :: thread:: ThreadId , FakeThreadId > ( thread:: current ( ) . id ( ) ) . 0
168+ std:: mem:: transmute :: < thread:: ThreadId , FakeThreadId > ( thread:: current ( ) . id ( ) ) . 0
169169 } ;
170170 u64:: from ( x) as usize
171171}
@@ -214,11 +214,10 @@ impl CoreBPE {
214214 let mut ret = vec ! [ ] ;
215215 for mat in regex. find_iter ( text) {
216216 let piece = mat. unwrap ( ) . as_str ( ) . as_bytes ( ) ;
217- if let Some ( token ) = self . encoder . get ( piece) {
218- ret. push ( * token) ;
219- continue ;
217+ match self . encoder . get ( piece) {
218+ Some ( token ) => ret. push ( * token) ,
219+ None => ret . extend ( & byte_pair_encode ( piece , & self . encoder ) ) ,
220220 }
221- ret. extend ( & byte_pair_encode ( piece, & self . encoder ) ) ;
222221 }
223222 ret
224223 }
@@ -516,7 +515,10 @@ impl CoreBPE {
516515 unstable_bytes. extend_from_slice ( & bytes[ e. valid_up_to ( ) ..] ) ;
517516
518517 tokens. truncate ( tokens. len ( ) - last_piece_token_len) ;
519- tokens. extend ( byte_pair_encode ( & unstable_bytes, & self . encoder ) ) ;
518+ match self . encoder . get ( & unstable_bytes) {
519+ Some ( token) => tokens. push ( * token) ,
520+ None => tokens. extend ( & byte_pair_encode ( & unstable_bytes, & self . encoder ) ) ,
521+ }
520522 }
521523 tokens
522524 }
@@ -597,15 +599,26 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
// Unit tests for `byte_pair_split`. This region is shown as a diff: lines marked `-` are the
// previous single test, lines marked `+` are the replacement fixture and tests.
597599mod tests {
598600 use rustc_hash:: FxHashMap as HashMap ;
599601
600- use crate :: byte_pair_split;
602+ use crate :: { byte_pair_split, Rank } ;
601603
602- #[ test]
603- fn very_simple_test ( ) {
604- let mut ranks = HashMap :: default ( ) ;
605- ranks. insert ( b"ab" . to_vec ( ) , 1 ) ;
606- ranks. insert ( b"cd" . to_vec ( ) , 2 ) ;
// Shared fixture: builds a rank table with two entries, b"ab" -> 0 and b"cd" -> 1.
604+ fn setup_ranks ( ) -> HashMap < Vec < u8 > , Rank > {
605+ HashMap :: from_iter ( [
606+ ( b"ab" . to_vec ( ) , 0 ) ,
607+ ( b"cd" . to_vec ( ) , 1 ) ,
608+ ] )
609+ }
607610
// Splitting b"abcd" with the fixture ranks is expected to yield the pieces b"ab" and b"cd".
611+ #[ test]
612+ fn test_simple_characters ( ) {
613+ let ranks = setup_ranks ( ) ;
608614 let res = byte_pair_split ( b"abcd" , & ranks) ;
609615 assert_eq ! ( res, vec![ b"ab" , b"cd" ] ) ;
610616 }
617+
// Splitting b"abab" (the same ranked pair twice) is expected to yield b"ab" followed by b"ab".
618+ #[ test]
619+ fn test_repeated_characters ( ) {
620+ let ranks = setup_ranks ( ) ;
621+ let res = byte_pair_split ( b"abab" , & ranks) ;
622+ assert_eq ! ( res, vec![ b"ab" , b"ab" ] ) ;
623+ }
611624}
0 commit comments