From 04a7286badf0a7c6773eecda2adb15ab4a6bade0 Mon Sep 17 00:00:00 2001 From: Rohit Narurkar Date: Mon, 6 May 2024 18:14:46 +0100 Subject: [PATCH 1/2] fix: account for variable bit-packing in fse code section --- aggregator/src/aggregation/decoder.rs | 172 ++++++++++++------ .../src/aggregation/decoder/tables/fixed.rs | 9 + .../tables/fixed/variable_bit_packing.rs | 87 +++++++++ 3 files changed, 209 insertions(+), 59 deletions(-) create mode 100644 aggregator/src/aggregation/decoder/tables/fixed/variable_bit_packing.rs diff --git a/aggregator/src/aggregation/decoder.rs b/aggregator/src/aggregation/decoder.rs index 492e6397c7..69f794cf52 100644 --- a/aggregator/src/aggregation/decoder.rs +++ b/aggregator/src/aggregation/decoder.rs @@ -443,18 +443,11 @@ pub struct BitstreamDecoder { bit_index_end_cmp_23: ComparatorConfig, /// The value of the binary bitstring. bitstring_value: Column, - /// Helper gadget to know when the bitstring value is 0. This contributes to an edge-case in - /// decoding and reconstructing the FSE table from normalised distributions, where a value=0 - /// implies prob=-1 ("less than 1" probability). In this case, the symbol is allocated a state - /// at the end of the FSE table, with baseline=0x00 and nb=AL, i.e. reset state. - bitstring_value_eq_0: IsEqualConfig, - /// Helper gadget to know when the bitstring value is 1 or 3. This is useful in the case - /// of decoding/reconstruction of FSE table, where a value=1 implies a special case of - /// prob=0, where the symbol is instead followed by a 2-bit repeat flag. The repeat flag - /// bits themselves could be followed by another 2-bit repeat flag if the repeat flag's - /// value is 3. - bitstring_value_eq_1: IsEqualConfig, - /// Helper config as per the above doc. + /// When we have encountered a symbol with value=1, i.e. prob=0, it is followed by 2-bits + /// repeat bits flag that tells us the number of symbols following the current one that also + /// have a probability of prob=0. If the repeat bits flag itself is [1, 1], i.e. + /// bitstring_value==3, then it is followed by another 2-bits repeat bits flag and so on. We + /// utilise this equality config to identify these cases. bitstring_value_eq_3: IsEqualConfig, /// Boolean that is set for a special case: /// - The bitstring that we have read in the current row is byte-aligned up to the next or the @@ -504,18 +497,6 @@ impl BitstreamDecoder { u8_table.into(), ), bitstring_value, - bitstring_value_eq_0: IsEqualChip::configure( - meta, - |meta| not::expr(meta.query_advice(is_padding, Rotation::cur())), - |meta| meta.query_advice(bitstring_value, Rotation::cur()), - |_| 0.expr(), - ), - bitstring_value_eq_1: IsEqualChip::configure( - meta, - |meta| not::expr(meta.query_advice(is_padding, Rotation::cur())), - |meta| meta.query_advice(bitstring_value, Rotation::cur()), - |_| 1.expr(), - ), bitstring_value_eq_3: IsEqualChip::configure( meta, |meta| not::expr(meta.query_advice(is_padding, Rotation::cur())), @@ -552,25 +533,6 @@ impl BitstreamDecoder { meta.query_advice(self.is_nb0, rotation) } - /// If the bitstring value is 0. - fn is_prob_less_than1( - &self, - meta: &mut VirtualCells, - rotation: Rotation, - ) -> Expression { - let bitstring_value = meta.query_advice(self.bitstring_value, rotation); - self.bitstring_value_eq_0 - .expr_at(meta, rotation, bitstring_value, 1.expr()) - } - - /// While reconstructing the FSE table, indicates whether a value=1 was found, i.e. prob=0. In - /// this case, the symbol is followed by 2-bits repeat flag instead. - fn is_prob0(&self, meta: &mut VirtualCells, rotation: Rotation) -> Expression { - let bitstring_value = meta.query_advice(self.bitstring_value, rotation); - self.bitstring_value_eq_1 - .expr_at(meta, rotation, bitstring_value, 1.expr()) - } - /// Whether the 2-bits repeat flag was [1, 1]. In this case, the repeat flag is followed by /// another repeat flag. fn is_rb_flag3(&self, meta: &mut VirtualCells, rotation: Rotation) -> Expression { @@ -702,6 +664,8 @@ pub struct FseDecoder { table_size: Column, /// The incremental symbol for which probability is decoded. symbol: Column, + /// The value decoded as per variable bit-packing. + value_decoded: Column, /// An accumulator of the number of states allocated to each symbol as we decode the FSE table. /// This is the normalised probability for the symbol. probability_acc: Column, @@ -709,17 +673,40 @@ pub struct FseDecoder { is_repeat_bits_loop: Column, /// Whether this row represents the 0-7 trailing bits that should be ignored. is_trailing_bits: Column, + /// Helper gadget to know when the decoded value is 0. This contributes to an edge-case in + /// decoding and reconstructing the FSE table from normalised distributions, where a value=0 + /// implies prob=-1 ("less than 1" probability). In this case, the symbol is allocated a state + /// at the end of the FSE table, with baseline=0x00 and nb=AL, i.e. reset state. + value_decoded_eq_0: IsEqualConfig, + /// Helper gadget to know when the decoded value is 1. This is useful in the edge-case in + /// decoding and reconstructing the FSE table, where a value=1 implies a special case of + /// prob=0, where the symbol is instead followed by a 2-bit repeat flag. + value_decoded_eq_1: IsEqualConfig, } impl FseDecoder { - fn configure(meta: &mut ConstraintSystem) -> Self { + fn configure(meta: &mut ConstraintSystem, is_padding: Column) -> Self { + let value_decoded = meta.advice_column(); Self { table_kind: meta.advice_column(), table_size: meta.advice_column(), symbol: meta.advice_column(), + value_decoded, probability_acc: meta.advice_column(), is_repeat_bits_loop: meta.advice_column(), is_trailing_bits: meta.advice_column(), + value_decoded_eq_0: IsEqualChip::configure( + meta, + |meta| not::expr(meta.query_advice(is_padding, Rotation::cur())), + |meta| meta.query_advice(value_decoded, Rotation::cur()), + |_| 0.expr(), + ), + value_decoded_eq_1: IsEqualChip::configure( + meta, + |meta| not::expr(meta.query_advice(is_padding, Rotation::cur())), + |meta| meta.query_advice(value_decoded, Rotation::cur()), + |_| 1.expr(), + ), } } } @@ -746,6 +733,25 @@ impl FseDecoder { * (table_kind.expr() - FseTableKind::MLT.expr()) * invert_of_2 } + + /// If the decoded value is 0. + fn is_prob_less_than1( + &self, + meta: &mut VirtualCells, + rotation: Rotation, + ) -> Expression { + let value_decoded = meta.query_advice(self.value_decoded, rotation); + self.value_decoded_eq_0 + .expr_at(meta, rotation, value_decoded, 1.expr()) + } + + /// While reconstructing the FSE table, indicates whether a value=1 was found, i.e. prob=0. In + /// this case, the symbol is followed by 2-bits repeat flag instead. + fn is_prob0(&self, meta: &mut VirtualCells, rotation: Rotation) -> Expression { + let value_decoded = meta.query_advice(self.value_decoded, rotation); + self.value_decoded_eq_1 + .expr_at(meta, rotation, value_decoded, 1.expr()) + } } #[derive(Clone, Debug)] @@ -964,7 +970,7 @@ impl DecoderConfig { let sequences_header_decoder = SequencesHeaderDecoder::configure(meta, byte, is_padding, u8_table); let bitstream_decoder = BitstreamDecoder::configure(meta, is_padding, u8_table); - let fse_decoder = FseDecoder::configure(meta); + let fse_decoder = FseDecoder::configure(meta, is_padding); let sequences_data_decoder = SequencesDataDecoder::configure(meta); // TODO(enable): @@ -2143,7 +2149,7 @@ impl DecoderConfig { cb.condition( and::expr([ not::expr(is_repeat_bits_loop.expr()), - config.bitstream_decoder.is_prob0(meta, Rotation::cur()), + config.fse_decoder.is_prob0(meta, Rotation::cur()), ]), |cb| { cb.require_equal( @@ -2193,12 +2199,20 @@ impl DecoderConfig { // updating and the FSE symbol itself. // // If no bitstring was read, even the symbol value is carried forward. - let (prob_acc_cur, prob_acc_prev, fse_symbol_cur, fse_symbol_prev, value) = ( + let ( + prob_acc_cur, + prob_acc_prev, + fse_symbol_cur, + fse_symbol_prev, + bitstring_value, + value_decoded, + ) = ( meta.query_advice(config.fse_decoder.probability_acc, Rotation::cur()), meta.query_advice(config.fse_decoder.probability_acc, Rotation::prev()), meta.query_advice(config.fse_decoder.symbol, Rotation::cur()), meta.query_advice(config.fse_decoder.symbol, Rotation::prev()), meta.query_advice(config.bitstream_decoder.bitstring_value, Rotation::cur()), + meta.query_advice(config.fse_decoder.value_decoded, Rotation::cur()), ); cb.condition( config.bitstream_decoder.is_nil(meta, Rotation::cur()), @@ -2231,11 +2245,9 @@ impl DecoderConfig { "fse: probability_acc is updated correctly", prob_acc_cur.expr(), select::expr( - config - .bitstream_decoder - .is_prob_less_than1(meta, Rotation::cur()), + config.fse_decoder.is_prob_less_than1(meta, Rotation::cur()), prob_acc_prev.expr() + 1.expr(), - prob_acc_prev.expr() + value.expr() - 1.expr(), + prob_acc_prev.expr() + value_decoded.expr() - 1.expr(), ), ); cb.require_equal( @@ -2268,7 +2280,7 @@ impl DecoderConfig { cb.require_equal( "fse: repeat-bits increases by the 2-bit value", fse_symbol_cur, - fse_symbol_prev + value, + fse_symbol_prev + bitstring_value, ); }); @@ -2379,6 +2391,49 @@ impl DecoderConfig { }, ); + meta.lookup_any( + "DecoderConfig: tag ZstdBlockSequenceFseCode (variable bit-packing)", + |meta| { + // At every row where a non-nil bitstring is read: + // - except the AL bits (is_change=true) + // - except when we are in repeat-bits loop + // - except the trailing bits (if they exist) + let condition = and::expr([ + meta.query_advice(config.tag_config.is_fse_code, Rotation::cur()), + config.bitstream_decoder.is_not_nil(meta, Rotation::cur()), + not::expr(meta.query_advice(config.tag_config.is_change, Rotation::cur())), + not::expr( + meta.query_advice(config.fse_decoder.is_repeat_bits_loop, Rotation::cur()), + ), + ]); + + let (table_size, probability_acc, value_read, value_decoded, num_bits) = ( + meta.query_advice(config.fse_decoder.table_size, Rotation::cur()), + meta.query_advice(config.fse_decoder.probability_acc, Rotation::prev()), + meta.query_advice(config.bitstream_decoder.bitstring_value, Rotation::cur()), + meta.query_advice(config.fse_decoder.value_decoded, Rotation::cur()), + config + .bitstream_decoder + .bitstring_len_unchecked(meta, Rotation::cur()), + ); + + let range = table_size - probability_acc + 1.expr(); + [ + FixedLookupTag::VariableBitPacking.expr(), + range, + value_read, + value_decoded, + num_bits, + 0.expr(), + 0.expr(), + ] + .into_iter() + .zip_eq(config.fixed_table.table_exprs(meta)) + .map(|(arg, table)| (condition.expr() * arg, table)) + .collect() + }, + ); + meta.lookup_any( "DecoderConfig: tag ZstdBlockSequenceFseCode (normalised probability of symbol)", |meta| { @@ -2391,7 +2446,7 @@ impl DecoderConfig { meta.query_advice(config.tag_config.is_fse_code, Rotation::cur()), config.bitstream_decoder.is_not_nil(meta, Rotation::cur()), not::expr(meta.query_advice(config.tag_config.is_change, Rotation::cur())), - not::expr(config.bitstream_decoder.is_prob0(meta, Rotation::cur())), + not::expr(config.fse_decoder.is_prob0(meta, Rotation::cur())), not::expr( meta.query_advice(config.fse_decoder.is_repeat_bits_loop, Rotation::cur()), ), @@ -2400,20 +2455,19 @@ impl DecoderConfig { ), ]); - let (block_idx, fse_table_kind, fse_table_size, fse_symbol, bitstring_value) = ( + let (block_idx, fse_table_kind, fse_table_size, fse_symbol, value_decoded) = ( meta.query_advice(config.block_config.block_idx, Rotation::cur()), meta.query_advice(config.fse_decoder.table_kind, Rotation::cur()), meta.query_advice(config.fse_decoder.table_size, Rotation::cur()), meta.query_advice(config.fse_decoder.symbol, Rotation::cur()), - meta.query_advice(config.bitstream_decoder.bitstring_value, Rotation::cur()), + meta.query_advice(config.fse_decoder.value_decoded, Rotation::cur()), ); - let is_prob_less_than1 = config - .bitstream_decoder - .is_prob_less_than1(meta, Rotation::cur()); + let is_prob_less_than1 = + config.fse_decoder.is_prob_less_than1(meta, Rotation::cur()); let norm_prob = select::expr( is_prob_less_than1.expr(), 1.expr(), - bitstring_value - 1.expr(), + value_decoded - 1.expr(), ); [ diff --git a/aggregator/src/aggregation/decoder/tables/fixed.rs b/aggregator/src/aggregation/decoder/tables/fixed.rs index 7edd6ed5ff..0a78eacd6e 100644 --- a/aggregator/src/aggregation/decoder/tables/fixed.rs +++ b/aggregator/src/aggregation/decoder/tables/fixed.rs @@ -29,6 +29,9 @@ use seq_tag_order::RomSeqTagOrder; mod tag_transition; use tag_transition::RomTagTransition; +mod variable_bit_packing; +use variable_bit_packing::RomVariableBitPacking; + pub trait FixedLookupValues { fn values() -> Vec<[Value; 7]>; } @@ -52,6 +55,11 @@ pub enum FixedLookupTag { /// Represents the FSE table reconstructed from the default distributions, i.e. Predefined FSE /// table. PredefinedFse, + /// Represents read and decoded values for the variable bit-packing as specified in the [zstd + /// comopression format][doclink]: + /// + /// doclink: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#fse-table-description + VariableBitPacking, } impl_expr!(FixedLookupTag); @@ -65,6 +73,7 @@ impl FixedLookupTag { Self::SeqCodeToValue => RomSeqCodeToValue::values(), Self::FseTableTransition => RomFseTableTransition::values(), Self::PredefinedFse => RomPredefinedFse::values(), + Self::VariableBitPacking => RomVariableBitPacking::values(), } } } diff --git a/aggregator/src/aggregation/decoder/tables/fixed/variable_bit_packing.rs b/aggregator/src/aggregation/decoder/tables/fixed/variable_bit_packing.rs new file mode 100644 index 0000000000..5f8ec18c6e --- /dev/null +++ b/aggregator/src/aggregation/decoder/tables/fixed/variable_bit_packing.rs @@ -0,0 +1,87 @@ +use halo2_proofs::{circuit::Value, halo2curves::bn256::Fr}; + +use crate::aggregation::decoder::witgen::util::bit_length; + +use super::{FixedLookupTag, FixedLookupValues}; + +#[derive(Clone, Debug)] +pub struct RomVariableBitPacking { + range: u64, + value_read: u64, + value_decoded: u64, + num_bits: u64, +} + +impl FixedLookupValues for RomVariableBitPacking { + fn values() -> Vec<[Value; 7]> { + // The maximum range R we ever have is 512 (1 << 9) as the maximum possible accuracy log is + // 9. So we only need to support a range up to R + 1, i.e. 513. + let rows = (0..=513) + .flat_map(|range| { + // Get the number of bits required to represent the highest number in this range. + let size = bit_length(range) as u32; + let max = 1 << size; + + // Whether ``range`` is a power of 2 minus 1, i.e. 2^k - 1. In these cases, we + // don't need variable bit-packing as all values in the range can be represented by + // the same number of bits. + let is_no_var = range & (range + 1) == 0; + + // The value read is in fact the value decoded. + if is_no_var { + return (0..=range) + .map(|value_read| RomVariableBitPacking { + range, + value_read, + value_decoded: value_read, + num_bits: size as u64, + }) + .collect::>(); + } + + let n_total = range + 1; + let lo_pin = max - n_total; + let n_remaining = n_total - lo_pin; + let hi_pin_1 = lo_pin + (n_remaining / 2); + let hi_pin_2 = max - (n_remaining / 2); + + (0..max) + .map(|value_read| { + // the value denoted by the low (size - 1)-bits. + let lo_value = value_read & ((1 << (size - 1)) - 1); + let (num_bits, value_decoded) = if (0..lo_pin).contains(&lo_value) { + (size - 1, lo_value) + } else if (lo_pin..hi_pin_1).contains(&value_read) { + (size, value_read) + } else if (hi_pin_1..hi_pin_2).contains(&value_read) { + (size - 1, value_read - hi_pin_1) + } else { + assert!((hi_pin_2..max).contains(&value_read)); + (size, value_read - lo_pin) + }; + RomVariableBitPacking { + range, + value_read, + value_decoded, + num_bits: num_bits.into(), + } + }) + .collect::>() + }) + .collect::>(); + + rows.iter() + .map(|row| { + [ + Value::known(Fr::from(FixedLookupTag::VariableBitPacking as u64)), + Value::known(Fr::from(row.range)), + Value::known(Fr::from(row.value_read)), + Value::known(Fr::from(row.value_decoded)), + Value::known(Fr::from(row.num_bits)), + Value::known(Fr::zero()), + Value::known(Fr::zero()), + ] + }) + .collect() + } +} From cda4d1de21ac6925a483a16dc2a3584082295b76 Mon Sep 17 00:00:00 2001 From: Rohit Narurkar Date: Mon, 6 May 2024 18:25:43 +0100 Subject: [PATCH 2/2] chore: range starts from 1 (ignore 0) --- .../aggregation/decoder/tables/fixed/variable_bit_packing.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aggregator/src/aggregation/decoder/tables/fixed/variable_bit_packing.rs b/aggregator/src/aggregation/decoder/tables/fixed/variable_bit_packing.rs index 5f8ec18c6e..6b32c1c42f 100644 --- a/aggregator/src/aggregation/decoder/tables/fixed/variable_bit_packing.rs +++ b/aggregator/src/aggregation/decoder/tables/fixed/variable_bit_packing.rs @@ -16,7 +16,7 @@ impl FixedLookupValues for RomVariableBitPacking { fn values() -> Vec<[Value; 7]> { // The maximum range R we ever have is 512 (1 << 9) as the maximum possible accuracy log is // 9. So we only need to support a range up to R + 1, i.e. 513. - let rows = (0..=513) + let rows = (1..=513) .flat_map(|range| { // Get the number of bits required to represent the highest number in this range. let size = bit_length(range) as u32;