diff --git a/Cargo.lock b/Cargo.lock
index 8f40d0b..da9149a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -972,7 +972,7 @@ dependencies = [
 [[package]]
 name = "encoder"
 version = "0.1.0"
-source = "git+https://github.com/scroll-tech/da-codec.git?tag=v0.1.2#8c5d2f0cd707153151a5154fef702204f6ca40b3"
+source = "git+https://github.com/scroll-tech/da-codec.git?branch=test%2Fprim_zstd#ad2db39b5e1ef4532f98ad0de0cc4c4f5f059308"
 dependencies = [
  "zstd",
 ]
@@ -4805,8 +4805,9 @@ dependencies = [
 
 [[package]]
 name = "zstd"
-version = "0.13.0"
-source = "git+https://github.com/scroll-tech/zstd-rs?branch=hack%2Fmul-block#5c0892b6567dab31394d701477183ce9d6a32aca"
+version = "0.13.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
 dependencies = [
  "zstd-safe",
 ]
@@ -4821,16 +4822,18 @@ dependencies = [
 
 [[package]]
 name = "zstd-safe"
-version = "7.0.0"
-source = "git+https://github.com/scroll-tech/zstd-rs?branch=hack%2Fmul-block#5c0892b6567dab31394d701477183ce9d6a32aca"
+version = "7.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d"
 dependencies = [
  "zstd-sys",
 ]
 
 [[package]]
 name = "zstd-sys"
-version = "2.0.9+zstd.1.5.5"
-source = "git+https://github.com/scroll-tech/zstd-rs?branch=hack%2Fmul-block#5c0892b6567dab31394d701477183ce9d6a32aca"
+version = "2.0.15+zstd.1.5.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237"
 dependencies = [
  "cc",
  "pkg-config",
diff --git a/Cargo.toml b/Cargo.toml
index 6f75e20..7e02d64 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,14 +15,14 @@ strum = "0.25"
 strum_macros = "0.25"
 anyhow = "1"
 serde = { version = "1", default-features = false, features = ["derive"] }
-zstd-encoder = { package = "encoder", git = "https://github.com/scroll-tech/da-codec.git", tag = "v0.1.2", optional = true }
+zstd-encoder = { package = "encoder", git = "https://github.com/scroll-tech/da-codec.git", branch = "test/prim_zstd", optional = true }
 
 [dev-dependencies]
 hex = "0.4"
 openvm = { git = "https://github.com/openvm-org/openvm.git", default-features = false, features = ["std"] }
 openvm-sdk = { git = "https://github.com/openvm-org/openvm.git", default-features = false}
 openvm-transpiler = { git = "https://github.com/openvm-org/openvm.git", default-features = false }
-zstd-encoder = { package = "encoder", git = "https://github.com/scroll-tech/da-codec.git", tag = "v0.1.2" }
+zstd-encoder = { package = "encoder", git = "https://github.com/scroll-tech/da-codec.git", branch = "test/prim_zstd" }
 
 [features]
 zstd = ["dep:zstd-encoder"]
@@ -30,3 +30,4 @@ zstd = ["dep:zstd-encoder"]
 [[bench]]
 harness = false
 name = "cycles"
+required-features = ["zstd"]
\ No newline at end of file
diff --git a/src/decoding.rs b/src/decoding.rs
index f77db94..924cfff 100644
--- a/src/decoding.rs
+++ b/src/decoding.rs
@@ -63,10 +63,9 @@ fn process_frame_header(
                 encoded_len: last_state.encoded_data.encoded_len,
             },
             decoded_data: last_state.decoded_data,
-            bitstream_read_data: None,
-            fse_data: None,
             literal_data: Vec::new(),
             repeated_offset: last_state.repeated_offset,
+            last_fse_table: last_state.last_fse_table,
         },
     ))
 }
@@ -161,10 +160,9 @@ fn process_block_header(
                 encoded_len: last_state.encoded_data.encoded_len,
             },
             literal_data: Vec::new(),
-            bitstream_read_data: None,
             decoded_data: last_state.decoded_data,
-            fse_data: None,
             repeated_offset: last_state.repeated_offset,
+            last_fse_table: last_state.last_fse_table,
         },
         block_info,
     ))
 }
@@ -209,92 +207,13 @@ fn process_block_zstd(
         },
         literal_data,
         decoded_data: last_state.decoded_data,
-        fse_data: None,
         repeated_offset: last_state.repeated_offset,
-        bitstream_read_data: None,
+        last_fse_table: last_state.last_fse_table,
     };
 
-    // println!("offset after literal body {} of block {}", last_state.encoded_data.byte_idx, block_idx);
-    // let LiteralsBlockResult {
-    //     offset: byte_offset,
-    //     witness_rows: rows,
-    //     literals,
-    // } = {
-    //     let last_row = rows.last().cloned().unwrap();
-    //     let multiplier =
-    //         (0..last_row.state.tag_len).fold(Value::known(F::one()), |acc, _| acc * randomness);
-    //     let value_rlc = last_row.encoded_data.value_rlc * multiplier + last_row.state.tag_rlc;
-    //     let tag = ZstdTag::ZstdBlockLiteralsRawBytes;
-    //     let tag_next = ZstdTag::ZstdBlockSequenceHeader;
-    //     let literals = src[byte_offset..(byte_offset + regen_size)].to_vec();
-    //     let tag_rlc_iter = literals.iter().scan(Value::known(F::zero()), |acc, &byte| {
-    //         *acc = *acc * randomness + Value::known(F::from(byte as u64));
-    //         Some(*acc)
-    //     });
-    //     let tag_rlc = tag_rlc_iter.clone().last().expect("Literals must exist.");
-
-    //     LiteralsBlockResult {
-    //         offset: byte_offset + regen_size,
-    //         witness_rows: literals
-    //             .iter()
-    //             .zip(tag_rlc_iter)
-    //             .enumerate()
-    //             .map(|(i, (&value_byte, tag_rlc_acc))| ZstdWitnessRow {
-    //                 state: ZstdState {
-    //                     tag,
-    //                     tag_next,
-    //                     block_idx,
-    //                     max_tag_len: tag.max_len(),
-    //                     tag_len: regen_size as u64,
-    //                     tag_idx: (i + 1) as u64,
-    //                     is_tag_change: i == 0,
-    //                     tag_rlc,
-    //                     tag_rlc_acc,
-    //                 },
-    //                 encoded_data: EncodedData {
-    //                     byte_idx: (byte_offset + i + 1) as u64,
-    //                     encoded_len: last_row.encoded_data.encoded_len,
-    //                     value_byte,
-    //                     value_rlc,
-    //                     reverse: false,
-    //                     ..Default::default()
-    //                 },
-    //                 decoded_data: DecodedData {
-    //                     decoded_len: last_row.decoded_data.decoded_len,
-    //                 },
-    //                 bitstream_read_data: BitstreamReadRow::default(),
-    //                 fse_data: FseDecodingRow::default(),
-    //             })
-    //             .collect::>(),
-    //         literals: literals.iter().map(|b| *b as u64).collect::>(),
-    //     }
-    // };
 
     let last_state = process_sequences(src, block_idx, expected_end_offset, last_state, last_block)?;
-    // let SequencesProcessingResult {
-    //     offset,
-    //     witness_rows: rows,
-    //     fse_aux_tables,
-    //     address_table_rows,
-    //     original_bytes,
-    //     sequence_info,
-    //     sequence_exec,
-    //     repeated_offset,
-    // } = process_sequences::(
-    //     src,
-    //     decoded_bytes,
-    //     block_idx,
-    //     byte_offset,
-    //     expected_end_offset,
-    //     literals.clone(),
-    //     last_row,
-    //     last_block,
-    //     randomness,
-    //     repeated_offset,
-    // );
-
     // sanity check:
     assert_eq!(
         last_state.encoded_data.byte_idx as usize,
         expected_end_offset,
@@ -324,7 +243,7 @@ fn process_sequences(
     let byte0 = src[0];
     assert!(byte0 > 0u8, "Sequences can't be of 0 length");
 
-    let (_num_of_sequences, num_sequence_header_bytes) = if byte0 < 128 {
+    let (num_of_sequences, num_sequence_header_bytes) = if byte0 < 128 {
         (byte0 as u64, 2usize)
     } else {
         assert!(src.len() >= 2, "Next byte of sequence header must exist.");
@@ -338,6 +257,39 @@ fn process_sequences(
         }
     };
 
+    if num_of_sequences == 0 {
+        // this case is highly unlikely: since we use raw literals, such a block would
+        // definitely be turned into a raw block
+        let mut literal_bytes = last_state
+            .literal_data
+            .into_iter()
+            .map(|v| v as u8)
+            .collect::<Vec<u8>>();
+        let mut decoded_data = last_state.decoded_data;
+        decoded_data.append(&mut literal_bytes);
+        return Ok(ZstdDecodingState {
+            state: ZstdState {
+                tag: ZstdTag::ZstdBlockSequenceHeader,
+                tag_next: if last_block {
+                    ZstdTag::Null
+                } else {
+                    ZstdTag::BlockHeader
+                },
+                block_idx,
+                max_tag_len: ZstdTag::ZstdBlockSequenceHeader.max_len(),
+                tag_len: num_sequence_header_bytes as u64,
+            },
+            encoded_data: EncodedDataCursor {
+                byte_idx: byte_offset + num_sequence_header_bytes as u64,
+                encoded_len: last_state.encoded_data.encoded_len,
+            },
+            decoded_data,
+            literal_data: Vec::new(),
+            repeated_offset: last_state.repeated_offset,
+            last_fse_table: last_state.last_fse_table,
+        });
+    }
+
     assert!(
         src.len() >= num_sequence_header_bytes,
         "Compression mode byte must exist."
     );
@@ -352,27 +304,6 @@ fn process_sequences(
     assert!(reserved == 0, "Reserved bits must be 0");
 
-    // Note: Only 2 modes of FSE encoding are accepted (instead of 4):
-    // 0 - Predefined.
-    // 2 - Variable bit packing.
-    assert!(
-        literal_lengths_mode == 2 || literal_lengths_mode == 0,
-        "Only FSE_Compressed_Mode or Predefined are allowed"
-    );
-    assert!(
-        offsets_mode == 2 || offsets_mode == 0,
-        "Only FSE_Compressed_Mode or Predefined are allowed"
-    );
-    assert!(
-        match_lengths_mode == 2 || match_lengths_mode == 0,
-        "Only FSE_Compressed_Mode or Predefined are allowed"
-    );
-    let _compression_mode = [
-        literal_lengths_mode > 0,
-        offsets_mode > 0,
-        match_lengths_mode > 0,
-    ];
-
     let is_all_predefined_fse = literal_lengths_mode + offsets_mode + match_lengths_mode < 1;
 
     let last_state = ZstdDecodingState {
@@ -391,11 +322,10 @@ fn process_sequences(
             byte_idx: byte_offset + num_sequence_header_bytes as u64,
             encoded_len: last_state.encoded_data.encoded_len,
         },
-        bitstream_read_data: None,
         decoded_data: last_state.decoded_data,
-        fse_data: None,
         literal_data: last_state.literal_data,
         repeated_offset: last_state.repeated_offset,
+        last_fse_table: last_state.last_fse_table,
     };
 
     /////////////////////////////////////////////////
@@ -408,50 +338,77 @@ fn process_sequences(
     let src = &src[num_sequence_header_bytes..];
 
     // Literal Length Table (LLT)
-    let (n_fse_bytes_llt, table_llt) = FseAuxiliaryTableData::reconstruct(
-        src,
-        block_idx,
-        FseTableKind::LLT,
-        literal_lengths_mode < 2,
-    )
-    .expect("Reconstructing FSE-packed Literl Length (LL) table should not fail.");
+    let (n_fse_bytes_llt, table_llt) = match literal_lengths_mode {
+        0 | 2 => FseAuxiliaryTableData::reconstruct(
+            src,
+            block_idx,
+            FseTableKind::LLT,
+            literal_lengths_mode == 0,
+        )
+        .expect("Reconstructing FSE-packed Literal Length (LL) table should not fail."),
+        1 => FseAuxiliaryTableData::reconstruct_rle(src, block_idx)
+            .expect("Reconstructing RLE Literal Length (LL) table should not fail."),
+        3 => (
+            0,
+            last_state.last_fse_table[0]
+                .clone()
+                .expect("Repeated Literal Length (LL) table should exist"),
+        ),
+        _ => unreachable!(""),
+    };
+    let llt = table_llt.parse_state_table();
 
     // Determine the accuracy log of LLT
-    let al_llt = if literal_lengths_mode > 0 {
-        table_llt.accuracy_log
-    } else {
-        6
-    };
+    let al_llt = table_llt.accuracy_log;
 
     // Cooked Match Offset Table (CMOT)
     let src = &src[n_fse_bytes_llt..];
-    let (n_fse_bytes_cmot, table_cmot) =
-        FseAuxiliaryTableData::reconstruct(src, block_idx, FseTableKind::MOT, offsets_mode < 2)
-            .expect("Reconstructing FSE-packed Cooked Match Offset (CMO) table should not fail.");
+    let (n_fse_bytes_cmot, table_cmot) = match offsets_mode {
+        0 | 2 => {
+            FseAuxiliaryTableData::reconstruct(src, block_idx, FseTableKind::MOT, offsets_mode == 0)
+                .expect(
"Reconstructing FSE-packed Cooked Match Offset (CMO) table should not fail.", + ) + } + 1 => FseAuxiliaryTableData::reconstruct_rle(src, block_idx) + .expect("Reconstructing RLE Cooked Match Offset (CMO) table should not fail."), + 3 => ( + 0, + last_state.last_fse_table[1] + .clone() + .expect("Repeatd Cooked Match Offset (CMO) table should be existed"), + ), + _ => unreachable!(""), + }; + let cmot = table_cmot.parse_state_table(); // Determine the accuracy log of CMOT - let al_cmot = if offsets_mode > 0 { - table_cmot.accuracy_log - } else { - 5 - }; + let al_cmot = table_cmot.accuracy_log; // Match Length Table (MLT) let src = &src[n_fse_bytes_cmot..]; - let (n_fse_bytes_mlt, table_mlt) = FseAuxiliaryTableData::reconstruct( - src, - block_idx, - FseTableKind::MLT, - match_lengths_mode < 2, - ) - .expect("Reconstructing FSE-packed Match Length (ML) table should not fail."); + let (n_fse_bytes_mlt, table_mlt) = match match_lengths_mode { + 0 | 2 => FseAuxiliaryTableData::reconstruct( + src, + block_idx, + FseTableKind::MLT, + match_lengths_mode == 0, + ) + .expect("Reconstructing FSE-packed Match Length (ML) table should not fail."), + 1 => FseAuxiliaryTableData::reconstruct_rle(src, block_idx) + .expect("Reconstructing RLE Match Length (ML) table should not fail."), + 3 => ( + 0, + last_state.last_fse_table[2] + .clone() + .expect("Repeatd Match Length (ML) table should be existed"), + ), + _ => unreachable!(""), + }; + let mlt = table_mlt.parse_state_table(); // Determine the accuracy log of MLT - let al_mlt = if match_lengths_mode > 0 { - table_mlt.accuracy_log - } else { - 6 - }; + let al_mlt = table_mlt.accuracy_log; let last_tag_len = if offsets_mode + match_lengths_mode < 1 { n_fse_bytes_llt @@ -476,11 +433,14 @@ fn process_sequences( byte_idx: byte_offset + (n_fse_bytes_llt + n_fse_bytes_cmot + n_fse_bytes_mlt) as u64, encoded_len: last_state.encoded_data.encoded_len, }, - bitstream_read_data: None, decoded_data: last_state.decoded_data, - fse_data: None, literal_data: last_state.literal_data, repeated_offset: last_state.repeated_offset, + last_fse_table: if num_of_sequences == 0 { + last_state.last_fse_table + } else { + [Some(table_llt), Some(table_cmot), Some(table_mlt)] + }, }; let byte_offset = last_state.encoded_data.byte_idx; @@ -533,11 +493,10 @@ fn process_sequences( byte_idx: byte_offset + n_sequence_data_bytes as u64, encoded_len: last_state.encoded_data.encoded_len, }, - bitstream_read_data: None, decoded_data: last_state.decoded_data, - fse_data: None, literal_data: last_state.literal_data, repeated_offset: last_state.repeated_offset, + last_fse_table: last_state.last_fse_table, }; // Exclude the leading zero section @@ -584,7 +543,7 @@ fn process_sequences( let mut is_init = true; let mut nb = nb_switch[mode][order_idx]; let bitstream_end_bit_idx = n_sequence_data_bytes * N_BITS_PER_BYTE; - let mut table_kind; + //let mut table_kind; let mut last_states: [u64; 3] = [0, 0, 0]; let mut last_symbols: [u64; 3] = [0, 0, 0]; let mut current_decoding_state; @@ -606,17 +565,17 @@ fn process_sequences( current_decoding_state = (mode * 3 + order_idx) as u64; - table_kind = match new_decoded.0 { - SequenceDataTag::CookedMatchOffsetFse | SequenceDataTag::CookedMatchOffsetValue => { - table_cmot.table_kind as u64 - } - SequenceDataTag::MatchLengthFse | SequenceDataTag::MatchLengthValue => { - table_mlt.table_kind as u64 - } - SequenceDataTag::LiteralLengthFse | SequenceDataTag::LiteralLengthValue => { - table_llt.table_kind as u64 - } - }; + // table_kind = match new_decoded.0 { 
+            //     SequenceDataTag::CookedMatchOffsetFse | SequenceDataTag::CookedMatchOffsetValue => {
+            //         table_cmot.table_kind as u64
+            //     }
+            //     SequenceDataTag::MatchLengthFse | SequenceDataTag::MatchLengthValue => {
+            //         table_mlt.table_kind as u64
+            //     }
+            //     SequenceDataTag::LiteralLengthFse | SequenceDataTag::LiteralLengthValue => {
+            //         table_llt.table_kind as u64
+            //     }
+            // };
 
             // FSE state update step
             curr_baseline = state_baselines[order_idx];
@@ -653,17 +612,17 @@ fn process_sequences(
 
             current_decoding_state = (mode * 3 + order_idx) as u64;
 
-            table_kind = match new_decoded.0 {
-                SequenceDataTag::CookedMatchOffsetFse | SequenceDataTag::CookedMatchOffsetValue => {
-                    table_cmot.table_kind as u64
-                }
-                SequenceDataTag::MatchLengthFse | SequenceDataTag::MatchLengthValue => {
-                    table_mlt.table_kind as u64
-                }
-                SequenceDataTag::LiteralLengthFse | SequenceDataTag::LiteralLengthValue => {
-                    table_llt.table_kind as u64
-                }
-            };
+            // table_kind = match new_decoded.0 {
+            //     SequenceDataTag::CookedMatchOffsetFse | SequenceDataTag::CookedMatchOffsetValue => {
+            //         table_cmot.table_kind as u64
+            //     }
+            //     SequenceDataTag::MatchLengthFse | SequenceDataTag::MatchLengthValue => {
+            //         table_mlt.table_kind as u64
+            //     }
+            //     SequenceDataTag::LiteralLengthFse | SequenceDataTag::LiteralLengthValue => {
+            //         table_llt.table_kind as u64
+            //     }
+            // };
 
             // Value decoding step
             curr_baseline = decoding_baselines[order_idx];
@@ -683,11 +642,12 @@ fn process_sequences(
         // a separate row needs to be added for each of such byte to ensure continuity of the value
         // accumulators. These compensating rows have is_nil=true. At most, two bytes can be
         // entirely covered by a bitstream read operation.
-        let multi_byte_boundaries: [usize; 2] = [15, 23];
+        let multi_byte_boundaries: [usize; 3] = [15, 23, 31];
         let mut skipped_bits = 0usize;
 
         for boundary in multi_byte_boundaries {
             if to_bit_idx >= boundary {
+                // TODO: increments the cursor 8 times (bit by bit); can be optimized
                 // Skip over covered bytes for byte and bit index
                 for _ in 0..N_BITS_PER_BYTE {
                     (current_byte_idx, current_bit_idx) =
@@ -702,9 +662,10 @@ fn process_sequences(
                 match to_bit_idx {
                     15 => 8,
                     16..=23 => 16,
+                    24..=31 => 24,
                     v => unreachable!(
                         "unexpected bit_index_end={:?} in (table={:?}, update_f?={:?}) (bit_index_start={:?}, bitstring_len={:?})",
-                        v, table_kind, (current_decoding_state >= 3), from_bit_idx, to_bit_idx - from_bit_idx + 1,
+                        v, order_idx, (current_decoding_state >= 3), from_bit_idx, to_bit_idx - from_bit_idx + 1,
                     ),
                 };
             }
@@ -869,11 +830,10 @@ fn process_sequences(
            byte_idx: byte_offset + n_sequence_data_bytes as u64,
            encoded_len: last_state.encoded_data.encoded_len,
        },
-        bitstream_read_data: None,
        decoded_data: decoded_bytes,
-        fse_data: None,
        literal_data: Vec::new(),
        repeated_offset,
+        last_fse_table: last_state.last_fse_table,
    })
 }
@@ -928,38 +888,16 @@ fn process_block_zstd_literals_header(
                encoded_len: last_state.encoded_data.encoded_len,
            },
            decoded_data: last_state.decoded_data,
-            bitstream_read_data: None,
-            fse_data: None,
            literal_data: Vec::new(),
            repeated_offset: last_state.repeated_offset,
+            last_fse_table: last_state.last_fse_table,
        },
        regen_size,
    ))
 }
 
-// Result for processing multiple blocks from compressed data
-// #[derive(Debug, Clone)]
-// pub struct MultiBlockProcessResult {
-//     pub witness_rows: Vec>,
-//     pub literal_bytes: Vec>, // literals
-//     pub fse_aux_tables: Vec,
-//     pub block_info_arr: Vec,
-//     pub sequence_info_arr: Vec,
-//     pub address_table_rows: Vec>,
-//     pub sequence_exec_results: Vec,
-// }
-
 /// Process a slice of bytes into decompression circuit witness rows
 pub fn process(src: &[u8]) -> Result<ZstdDecodingState> {
-    // let mut witness_rows = vec![];
-    // let mut decoded_bytes: Vec = vec![];
-    // let mut literals: Vec> = vec![];
-    // let mut fse_aux_tables: Vec = vec![];
-    // let mut block_info_arr: Vec = vec![];
-    // let mut sequence_info_arr: Vec = vec![];
-    // let mut address_table_arr: Vec> = vec![];
-    // let mut sequence_exec_info_arr: Vec = vec![];
-    //
     // FrameHeaderDescriptor and FrameContentSize
     let (_frame_content_size, mut last_state) =
         process_frame_header(src, ZstdDecodingState::init(src.len()))?;
@@ -998,36 +936,62 @@ mod tests {
         init_zstd_encoder_n(target_block_size.unwrap_or(N_BLOCK_SIZE_TARGET))
     }
 
-    #[test]
-    fn test_zstd_witness_processing_batch_data() -> Result<(), std::io::Error> {
+    fn test_processing(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
         use super::*;
+        let compressed = {
+            // compression level = 0 defaults to using level=3, which is zstd's default.
+            let mut encoder = init_zstd_encoder(None);
+
+            // set source length, which will be reflected in the frame header.
+            encoder.window_log(24)?;
+            // reduce the compressed block size, so we have a better chance to exercise the fse repeated mode header
+            encoder.set_target_cblock_size(Some(16384))?;
+            encoder.set_pledged_src_size(Some(data.len() as u64))?;
+
+            encoder.write_all(data)?;
+            encoder.finish()?
+        };
+        let state = process(&compressed).unwrap();
+        Ok(state.decoded_data)
+    }
+
+    fn read_sample() -> Result<impl Iterator<Item = Vec<u8>>, std::io::Error> {
         let mut batch_files = fs::read_dir("./data/test_batches")?
             .map(|entry| entry.map(|e| e.path()))
             .collect::<Result<Vec<_>, std::io::Error>>()?;
         batch_files.sort();
 
-        let batches = batch_files
-            .iter()
+        Ok(batch_files
+            .into_iter()
             .map(fs::read_to_string)
             .filter_map(|data| data.ok())
-            .map(|data| hex::decode(data.trim_end()).expect("Failed to decode hex data"))
-            .collect::<Vec<Vec<u8>>>();
+            .map(|data| hex::decode(data.trim_end()).expect("Failed to decode hex data")))
+    }
 
-        for raw_input_bytes in batches.into_iter() {
-            let compressed = {
-                // compression level = 0 defaults to using level=3, which is zstd's default.
-                let mut encoder = init_zstd_encoder(None);
+    #[test]
+    fn test_zstd_witness_processing_batch_data() -> Result<(), std::io::Error> {
+        for raw_input_bytes in read_sample()? {
+            let decoded_bytes = test_processing(&raw_input_bytes)?;
 
-                // set source length, which will be reflected in the frame header.
-                encoder.set_pledged_src_size(Some(raw_input_bytes.len() as u64))?;
+            assert!(raw_input_bytes == decoded_bytes);
+        }
 
-                encoder.write_all(&raw_input_bytes)?;
-                encoder.finish()?
-            };
+        Ok(())
+    }
 
-            let state = process(&compressed).unwrap();
+    #[test]
+    fn test_zstd_witness_processing_rle_data() -> Result<(), std::io::Error> {
+        for mut raw_input_bytes in read_sample()? {
+            // construct an RLE block and a long-distance reference
+            if raw_input_bytes.len() < 128 * 1024 {
+                let cur = raw_input_bytes.clone();
+                // construct an RLE run
+                raw_input_bytes.resize(256 * 1024, 42u8);
+                // then we can have a long-distance reference
+                raw_input_bytes.extend(cur);
+            }
 
-            let decoded_bytes = state.decoded_data;
+            let decoded_bytes = test_processing(&raw_input_bytes)?;
 
             assert!(raw_input_bytes == decoded_bytes);
         }
diff --git a/src/fse.rs b/src/fse.rs
index 2d908e9..06b7a10 100644
--- a/src/fse.rs
+++ b/src/fse.rs
@@ -17,8 +17,6 @@ pub struct FseTableRow {
     pub num_bits: u64,
     /// The symbol emitted by the FSE table at this state.
     pub symbol: u64,
-    /// During FSE table decoding, keep track of the number of symbol emitted
-    pub num_emitted: u64,
     /// A boolean marker to indicate that as per the state transition rules of FSE codes, this
     /// state was reached for this symbol, however it was already pre-allocated to a prior symbol,
     /// this can happen in case we have symbols with prob=-1.
@@ -32,8 +30,8 @@ pub struct FseAuxiliaryTableData {
     pub block_idx: u64,
     /// Indicates whether the table is pre-defined.
     pub is_predefined: bool,
-    /// The FSE table kind, variants are: LLT=1, MOT=2, MLT=3.
-    pub table_kind: FseTableKind,
+    /// In RLE mode, record the RLE symbol.
+    pub rle_symbol: Option<u8>,
     /// The FSE table's size, i.e. 1 << AL (accuracy log).
     pub table_size: u64,
     /// The accuracy log
@@ -46,8 +44,8 @@ pub struct FseAuxiliaryTableData {
     ///
     /// For each symbol, the states as per the state transition rule.
     pub sym_to_states: BTreeMap<u64, Vec<FseTableRow>>,
-    /// Similar map, but where the states for each symbol are in increasing order (sorted).
-    pub sym_to_sorted_states: BTreeMap<u64, Vec<FseTableRow>>,
+    // Similar map, but where the states for each symbol are in increasing order (sorted).
+    // pub sym_to_sorted_states: BTreeMap<u64, Vec<FseTableRow>>,
 }
 
 /// Another form of Fse table that has state as key instead of the FseSymbol.
@@ -58,6 +56,36 @@ type FseStateMapping = BTreeMap;
 type ReconstructedFse = (usize, FseAuxiliaryTableData);
 
 impl FseAuxiliaryTableData {
+    pub fn reconstruct_rle(src: &[u8], block_idx: u64) -> std::io::Result<ReconstructedFse> {
+        let symbol = src[0] as u64;
+        let mut sym_to_states = BTreeMap::new();
+        sym_to_states.insert(
+            symbol,
+            vec![FseTableRow {
+                state: 0,
+                baseline: 0,
+                num_bits: 0,
+                symbol,
+                is_state_skipped: false,
+            }],
+        );
+        let mut normalised_probs = BTreeMap::new();
+        normalised_probs.insert(symbol, 1);
+
+        Ok((
+            1,
+            Self {
+                block_idx,
+                is_predefined: false,
+                rle_symbol: Some(src[0]),
+                table_size: 1,
+                accuracy_log: 0,
+                normalised_probs,
+                sym_to_states,
+            },
+        ))
+    }
+
     /// While we reconstruct an FSE table from a bitstream, we do not know before reconstruction
     /// how many exact bytes we would finally be reading.
     ///
@@ -73,8 +101,8 @@ impl FseAuxiliaryTableData {
         is_predefined: bool,
     ) -> std::io::Result<ReconstructedFse> {
         // construct little-endian bit-reader.
-        let data = src.to_vec();
-        let mut reader = BitReader::endian(Cursor::new(&data), LittleEndian);
+        // let data = src.to_vec();
+        let mut reader = BitReader::endian(Cursor::new(src), LittleEndian);
         ////////////////////////////////////////////////////////////////////////////////////////
         //////////////////////////// Parse Normalised Probabilities ////////////////////////////
         ////////////////////////////////////////////////////////////////////////////////////////
@@ -124,7 +152,7 @@ impl FseAuxiliaryTableData {
                 // number of bits and value read from the variable bit-packed data.
                 // And update the total number of bits read so far.
                 let (n_bits_read, _value_read, value_decoded) =
-                    read_variable_bit_packing(&data, offset, R + 1)?;
+                    read_variable_bit_packing(src, offset, R + 1)?;
                 reader.skip(n_bits_read)?;
                 offset += n_bits_read;
@@ -207,20 +235,19 @@ impl FseAuxiliaryTableData {
         ////////////////////////////////////////////////////////////////////////////////////////
         ///////////////////////////// Allocate States to Symbols ///////////////////////////////
         ////////////////////////////////////////////////////////////////////////////////////////
-        let (sym_to_states, sym_to_sorted_states) =
-            Self::transform_normalised_probs(&normalised_probs, accuracy_log);
+        let sym_to_states = Self::transform_normalised_probs(&normalised_probs, accuracy_log);
 
         Ok((
             t,
             Self {
                 block_idx,
                 is_predefined,
-                table_kind,
+                rle_symbol: None,
                 table_size,
                 accuracy_log: accuracy_log as u64,
                 normalised_probs,
                 sym_to_states,
-                sym_to_sorted_states,
+                //sym_to_sorted_states,
             },
         ))
     }
@@ -229,15 +256,11 @@ impl FseAuxiliaryTableData {
     fn transform_normalised_probs(
         normalised_probs: &BTreeMap,
         accuracy_log: u8,
-    ) -> (
-        BTreeMap<u64, Vec<FseTableRow>>,
-        BTreeMap<u64, Vec<FseTableRow>>,
-    ) {
+    ) -> BTreeMap<u64, Vec<FseTableRow>> {
         // TODO: still need optimizations
         let table_size = 1 << accuracy_log;
         let mut sym_to_states = BTreeMap::new();
-        let mut sym_to_sorted_states = BTreeMap::new();
         let mut state = 0;
         let mut retreating_state = table_size - 1;
         let mut allocated_states = BTreeMap::::new();
@@ -254,10 +277,8 @@ impl FseAuxiliaryTableData {
                     baseline: 0,
                     symbol,
                     is_state_skipped: false,
-                    num_emitted: 0,
                 };
                 sym_to_states.insert(symbol, vec![fse_table_row.clone()]);
-                sym_to_sorted_states.insert(symbol, vec![fse_table_row]);
                 retreating_state -= 1;
             }
@@ -321,31 +342,14 @@ impl FseAuxiliaryTableData {
                         num_bits: nb,
                         baseline,
                         symbol,
-                        num_emitted: 0,
                         is_state_skipped,
                     }
                })
                .collect(),
            );
-            sym_to_sorted_states.insert(
-                symbol,
-                sorted_states
-                    .iter()
-                    .zip(nbs.iter())
-                    .zip(baselines.iter())
-                    .map(|((&s, &nb), &baseline)| FseTableRow {
-                        state: s,
-                        num_bits: nb,
-                        baseline,
-                        symbol,
-                        num_emitted: 0,
-                        is_state_skipped: false,
-                    })
-                    .collect(),
-            );
        }
 
-        (sym_to_states, sym_to_sorted_states)
+        sym_to_states
     }
 
     /// Convert an FseAuxiliaryTableData into a state-mapped representation.
@@ -387,10 +391,19 @@ mod tests {
 
         // TODO: assert equality for the entire table.
         // for now only comparing state/baseline/nb for S1, i.e. weight == 1.
+        let sorted_states = table
+            .sym_to_states
+            .get(&1)
+            .unwrap()
+            .iter()
+            .filter(|st| !st.is_state_skipped)
+            .sorted_by_key(|s| s.state)
+            .cloned()
+            .collect::<Vec<_>>();
         assert_eq!(n_bytes, 4);
         assert_eq!(
-            table.sym_to_sorted_states.get(&1).cloned().unwrap(),
+            sorted_states,
             [
                 (0x03, 0x10, 3),
                 (0x0c, 0x18, 3),
@@ -405,7 +418,6 @@
                     symbol: 1,
                     baseline,
                     num_bits,
-                    num_emitted: 0,
                     is_state_skipped: false,
                 })
                .collect::<Vec<FseTableRow>>(),
diff --git a/src/types.rs b/src/types.rs
index 3f1984c..d777143 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -472,16 +472,14 @@ pub struct ZstdDecodingState {
     pub state: ZstdState,
     /// Data cursor on compressed data
     pub encoded_data: EncodedDataCursor,
-    /// Bitstream reader cursor
-    pub bitstream_read_data: Option,
     /// decompressed data has been decoded
     pub decoded_data: Vec<u8>,
-    /// Fse decoding state transition data
-    pub fse_data: Option,
     /// literal dicts
     pub literal_data: Vec<u64>,
     /// the repeated offset for sequence
     pub repeated_offset: [usize; 3],
+    /// the cached FSE tables for repeated mode
+    pub last_fse_table: [Option<FseAuxiliaryTableData>; 3],
 }
 
 impl ZstdDecodingState {
@@ -494,10 +492,9 @@ impl ZstdDecodingState {
                 ..Default::default()
             },
             decoded_data: Vec::new(),
-            fse_data: None,
-            bitstream_read_data: None,
             literal_data: Vec::new(),
            repeated_offset: [1, 4, 8], // starting values, according to the spec
+            last_fse_table: Default::default(),
        }
    }
 }