11// This check is new and seems buggy (possibly with PyO3 interaction)
22#![ allow( clippy:: borrow_deref_ref) ]
33
4- use std:: collections:: HashSet ;
4+ use std:: collections:: { BTreeMap , BTreeSet , HashSet } ;
5+ use std:: iter:: successors;
56use std:: num:: NonZeroU64 ;
67use std:: thread;
78
@@ -15,9 +16,22 @@ use rustc_hash::FxHashMap as HashMap;
1516
/// Integer type used for token ranks (the values stored in the `ranks` maps).
type Rank = u32;

/// Length threshold, measured in bytes of the piece (`piece.len()`), at which
/// `_byte_pair_merge` switches from `_byte_pair_merge_small` to
/// `_byte_pair_merge_large`. Pieces strictly shorter than this take the small path.
const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500;
20+
1821fn _byte_pair_merge (
1922 ranks : & HashMap < Vec < u8 > , Rank > ,
2023 piece : & [ u8 ] ,
24+ ) -> Vec < ( usize , Rank ) > {
25+ if piece. len ( ) < LARGE_ENCODER_CHARACTER_LIMIT {
26+ _byte_pair_merge_small ( ranks, piece) // Quadratic, but lightweight
27+ } else {
28+ _byte_pair_merge_large ( ranks, piece) // Linearithmic, but heavy
29+ }
30+ }
31+
32+ fn _byte_pair_merge_small (
33+ ranks : & HashMap < Vec < u8 > , Rank > ,
34+ piece : & [ u8 ] ,
2135) -> Vec < ( usize , Rank ) > {
2236 let get_rank = |parts : & Vec < ( usize , _ ) > , start_idx : usize , end_idx : usize | {
2337 * parts. get ( end_idx)
@@ -56,6 +70,85 @@ fn _byte_pair_merge(
5670 parts
5771}
5872
73+ fn _byte_pair_merge_large (
74+ ranks : & HashMap < Vec < u8 > , Rank > ,
75+ piece : & [ u8 ] ,
76+ ) -> Vec < ( usize , Rank ) > {
77+ let get_rank = |start_idx : usize , end_idx : usize | {
78+ * piece. get ( start_idx..end_idx)
79+ . filter ( |p| p. len ( ) < piece. len ( ) )
80+ . and_then ( |p| ranks. get ( p) )
81+ . unwrap_or ( & Rank :: MAX )
82+ } ;
83+
84+ let mut rank_indexes = BTreeMap :: < Rank , BTreeSet < usize > > :: new ( ) ;
85+ let mut index_rank = vec ! [ Rank :: MAX ; piece. len( ) + 1 ] ;
86+ let mut index_prev = vec ! [ usize :: MAX ; piece. len( ) + 1 ] ;
87+ let mut index_next = vec ! [ usize :: MAX ; piece. len( ) + 1 ] ;
88+
89+ let mut prev_node = None ;
90+ for i in 0 ..=piece. len ( ) {
91+ let rank = get_rank ( i, i + 2 ) ;
92+ index_rank[ i] = rank;
93+ if let Some ( prev) = prev_node {
94+ index_prev[ i] = prev;
95+ index_next[ prev] = i;
96+ }
97+ prev_node = Some ( i) ;
98+
99+ rank_indexes. entry ( rank) . or_default ( ) . insert ( i) ;
100+ }
101+
102+ let mut token_count = piece. len ( ) ;
103+ while token_count > 2 && rank_indexes. len ( ) > 1 {
104+ let mut skip_next = false ;
105+ if let Some ( ( _, nodes) ) = rank_indexes. pop_first ( ) {
106+ for & min_node in & nodes {
107+ if skip_next {
108+ skip_next = false ;
109+ continue ;
110+ }
111+
112+ let min_rank = index_rank[ min_node] ;
113+
114+ let prev_node = index_prev[ min_node] ;
115+ let next_node = index_next[ min_node] ;
116+ let next_next_node = index_next[ next_node] ;
117+ let next_next_next_node = index_next[ next_next_node] ;
118+
119+ if prev_node != usize:: MAX {
120+ let new_rank = get_rank ( prev_node, next_next_node) ;
121+ if index_rank[ prev_node] != new_rank {
122+ rank_indexes. get_mut ( & index_rank[ prev_node] ) . unwrap ( ) . remove ( & prev_node) ;
123+ index_rank[ prev_node] = new_rank;
124+ rank_indexes. entry ( new_rank) . or_default ( ) . insert ( prev_node) ;
125+ }
126+ }
127+
128+ let new_rank = get_rank ( min_node, next_next_next_node) ;
129+ index_rank[ min_node] = new_rank;
130+ rank_indexes. entry ( new_rank) . or_default ( ) . insert ( min_node) ;
131+
132+ index_next[ min_node] = next_next_node;
133+ index_prev[ next_next_node] = min_node;
134+ let next_node_rank = index_rank[ next_node] ;
135+ if next_node_rank != Rank :: MAX {
136+ skip_next = next_node_rank == min_rank;
137+ if !skip_next {
138+ rank_indexes. get_mut ( & next_node_rank) . unwrap ( ) . remove ( & next_node) ;
139+ }
140+ }
141+
142+ token_count -= 1 ;
143+ }
144+ }
145+ }
146+
147+ successors ( Some ( 0 ) , |& n| index_next. get ( n) . filter ( |& x| * x != usize:: MAX ) . copied ( ) )
148+ . map ( |n| ( n, Rank :: MAX ) )
149+ . collect ( )
150+ }
151+
59152pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < Rank > {
60153 assert ! ( piece. len( ) > 1 ) ;
61154 _byte_pair_merge ( & ranks, & piece)
0 commit comments