@@ -8,17 +8,17 @@ use std::thread;
88use fancy_regex:: Regex ;
99use pyo3:: exceptions;
1010use pyo3:: prelude:: * ;
11+ use pyo3:: pyclass;
1112use pyo3:: PyResult ;
1213use pyo3:: types:: { PyBytes , PyList , PyTuple } ;
1314use rustc_hash:: FxHashMap as HashMap ;
1415
1516type Rank = u32 ;
1617
17- fn _byte_pair_merge < T > (
18- piece : & [ u8 ] ,
18+ fn _byte_pair_merge (
1919 ranks : & HashMap < Vec < u8 > , Rank > ,
20- f : impl Fn ( std :: ops :: Range < usize > ) -> T ,
21- ) -> Vec < T > {
20+ piece : & [ u8 ] ,
21+ ) -> Vec < ( usize , Rank ) > {
2222 // This is a vector of (start, rank).
2323 // The rank is of the byte pair starting at position start.
2424 // The rank of the last item in the vector is not a valid value.
@@ -93,25 +93,24 @@ fn _byte_pair_merge<T>(
9393 break ;
9494 }
9595 }
96- let mut out: Vec < T > = Vec :: with_capacity ( parts. len ( ) - 1 ) ;
97- for i in 0 ..parts. len ( ) - 1 {
98- out. push ( f ( parts[ i] . 0 ..parts[ i + 1 ] . 0 ) ) ;
99- }
100- out
96+
97+ parts
10198}
10299
103100pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < Rank > {
104- if piece. len ( ) == 1 {
105- return vec ! [ ranks[ piece] ] ;
106- }
107- _byte_pair_merge ( piece, ranks, |p| ranks[ & piece[ p. start ..p. end ] ] )
101+ assert ! ( piece. len( ) > 1 ) ;
102+ _byte_pair_merge ( & ranks, & piece)
103+ . windows ( 2 )
104+ . map ( |part| ranks[ & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] ] )
105+ . collect ( )
108106}
109107
110108pub fn byte_pair_split < ' a > ( piece : & ' a [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < & ' a [ u8 ] > {
111- if piece. len ( ) == 1 {
112- return vec ! [ piece] ;
113- }
114- _byte_pair_merge ( piece, ranks, |p| & piece[ p. start ..p. end ] )
109+ assert ! ( piece. len( ) > 1 ) ;
110+ _byte_pair_merge ( & ranks, & piece)
111+ . windows ( 2 )
112+ . map ( |part| & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] )
113+ . collect ( )
115114}
116115
117116// Various performance notes:
0 commit comments