@@ -19,78 +19,37 @@ fn _byte_pair_merge(
     ranks: &HashMap<Vec<u8>, Rank>,
     piece: &[u8],
 ) -> Vec<(usize, Rank)> {
-    // This is a vector of (start, rank).
-    // The rank is of the byte pair starting at position start.
-    // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
-
-    let get_rank = {
-        #[inline(always)]
-        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
-            if (start_idx + skip + 2) < parts.len() {
-                ranks
-                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
-                    .copied()
-            } else {
-                None
-            }
-        }
+    let get_rank = |parts: &Vec<(usize, _)>, start_idx: usize, end_idx: usize| {
+        *parts.get(end_idx)
+            .map(|e| parts.get(start_idx).unwrap().0..e.0)
+            .and_then(|r| piece.get(r))
+            .filter(|p| p.len() < piece.len())
+            .and_then(|p| ranks.get(p))
+            .unwrap_or(&Rank::MAX)
     };
 
-    // We look up the ranks once in the beginning and iteratively update
-    // them during each merge, which reduces the number of rank lookups.
-    for i in 0..parts.len() - 2 {
-        match get_rank(&parts, i, 0) {
-            Some(rank) => {
-                // Rank::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != Rank::MAX);
-                parts[i].1 = rank;
-            }
-            None => {
-                continue;
-            }
-        };
-    }
-
-    // If you have n parts and m merges, this does O(mn) work.
-    // We could do something with a heap and do O(m log n) work.
-    // It is important to consider that n is often small (<100), and as such
-    // the cache-locality benefits outweigh the algorithmic complexity downsides
-    // of the `parts` vector data structure above.
-
-    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
-    // currently do, this is equivalent. An easy way to break this would be to decouple
-    // merge priority from token index or to prevent specific token merges.
-    loop {
-        if parts.len() == 1 {
-            break;
+    let (mut min_rank_index, mut min_rank) = (0, Rank::MAX);
+    let mut parts = Vec::with_capacity(piece.len() + 1);
+    for i in 0..piece.len() + 1 {
+        let part = (i, *piece.get(i..i + 2).and_then(|p| ranks.get(p)).unwrap_or(&Rank::MAX));
+        if part.1 < min_rank {
+            (min_rank_index, min_rank) = part;
         }
+        parts.push(part);
+    }
 
-        // Rank::MAX is a sentinel rank value allowing us to
-        // take the min more quickly
-        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
-        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
-            if rank < min_rank.0 {
-                min_rank = (rank, i);
-            }
+    while parts.len() > 3 && min_rank != Rank::MAX {
+        if min_rank_index > 0 {
+            parts[min_rank_index - 1].1 = get_rank(&parts, min_rank_index - 1, min_rank_index + 2);
         }
+        parts[min_rank_index].1 = get_rank(&parts, min_rank_index, min_rank_index + 3);
+        parts.remove(min_rank_index + 1);
 
-        if min_rank.0 != Rank::MAX {
-            let i = min_rank.1;
-
-            // NOTE: We are about to remove parts[i + 1]. We do not do it
-            // yet because there are cache-locality benefits to updating
-            // parts[i] and parts[i-1] before removing, which could thrash
-            // the cache. Thus, we update the rank calculation by skipping over
-            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
-            if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
+        (min_rank_index, min_rank) = (0, parts[0].1);
+        for i in 1..parts.len() - 2 {
+            if parts[i].1 < min_rank {
+                (min_rank_index, min_rank) = (i, parts[i].1);
             }
-
-            parts.remove(i + 1);
-        } else {
-            break;
         }
     }
 
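For reference, below is a minimal, self-contained sketch of the merge loop introduced on the `+` side, runnable against a toy merge table. It assumes `Rank` is an alias for `u32` as elsewhere in the crate; the `byte_pair_merge_sketch` name, the toy vocabulary, and the `main` driver are illustrative only and not part of the commit.

// A standalone sketch of the new merge loop, under the assumptions above.
use std::collections::HashMap;

type Rank = u32;

fn byte_pair_merge_sketch(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // Rank of the pair formed by merging the bytes between two boundaries;
    // out-of-range or full-piece lookups fall back to the Rank::MAX sentinel.
    let get_rank = |parts: &Vec<(usize, Rank)>, start_idx: usize, end_idx: usize| {
        *parts.get(end_idx)
            .map(|e| parts.get(start_idx).unwrap().0..e.0)
            .and_then(|r| piece.get(r))
            .filter(|p| p.len() < piece.len())
            .and_then(|p| ranks.get(p))
            .unwrap_or(&Rank::MAX)
    };

    // Seed `parts` with one entry per byte offset plus an end sentinel,
    // tracking the lowest-ranked pair while building to save a separate pass.
    let (mut min_rank_index, mut min_rank) = (0, Rank::MAX);
    let mut parts = Vec::with_capacity(piece.len() + 1);
    for i in 0..piece.len() + 1 {
        let part = (i, *piece.get(i..i + 2).and_then(|p| ranks.get(p)).unwrap_or(&Rank::MAX));
        if part.1 < min_rank {
            (min_rank_index, min_rank) = part;
        }
        parts.push(part);
    }

    while parts.len() > 3 && min_rank != Rank::MAX {
        // Merge the best pair: refresh the ranks of the two pairs that
        // overlap it, then delete the boundary inside the merged pair.
        if min_rank_index > 0 {
            parts[min_rank_index - 1].1 = get_rank(&parts, min_rank_index - 1, min_rank_index + 2);
        }
        parts[min_rank_index].1 = get_rank(&parts, min_rank_index, min_rank_index + 3);
        parts.remove(min_rank_index + 1);

        // Linear re-scan for the next minimum; pieces are short, so this
        // beats a heap on cache locality.
        (min_rank_index, min_rank) = (0, parts[0].1);
        for i in 1..parts.len() - 2 {
            if parts[i].1 < min_rank {
                (min_rank_index, min_rank) = (i, parts[i].1);
            }
        }
    }
    parts
}

fn main() {
    // Toy merge table: "ab" merges first (rank 0), then "ab"+"c" (rank 1).
    let ranks: HashMap<Vec<u8>, Rank> =
        [(b"ab".to_vec(), 0), (b"abc".to_vec(), 1)].into_iter().collect();

    // "abcd" should split into "abc" and "d": starts at 0 and 3, plus the
    // end sentinel at 4.
    let parts = byte_pair_merge_sketch(&ranks, b"abcd");
    let starts: Vec<usize> = parts.iter().map(|&(start, _)| start).collect();
    assert_eq!(starts, vec![0, 3, 4]);
    println!("token starts: {:?}", starts);
}

The rewrite keeps the linear re-scan that the removed comments justify on cache-locality grounds, but folds the initial rank lookup into the construction of `parts` and has `get_rank` return the `Rank::MAX` sentinel directly rather than an `Option`, dropping the per-merge `unwrap_or` calls.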