diff --git a/zlib-rs/src/inflate.rs b/zlib-rs/src/inflate.rs index ae2f565f..7072a572 100644 --- a/zlib-rs/src/inflate.rs +++ b/zlib-rs/src/inflate.rs @@ -492,24 +492,6 @@ impl<'a> State<'a> { } } -macro_rules! pull_byte { - ($self:expr) => { - match $self.bit_reader.pull_byte() { - Err(return_code) => return $self.inflate_leave(return_code), - Ok(_) => (), - } - }; -} - -macro_rules! need_bits { - ($self:expr, $n:expr) => { - match $self.bit_reader.need_bits($n) { - Err(return_code) => return $self.inflate_leave(return_code), - Ok(v) => v, - } - }; -} - // swaps endianness const fn zswap32(q: u32) -> u32 { u32::from_be(q.to_le()) @@ -862,12 +844,38 @@ impl State<'_> { } fn dispatch(&mut self) -> ReturnCode { - 'label: loop { - match self.mode { + // Note: All early returns must save mode into self.mode again. + let mut mode = self.mode; + + macro_rules! pull_byte { + ($self:expr) => { + match $self.bit_reader.pull_byte() { + Err(return_code) => { + self.mode = mode; + return $self.inflate_leave(return_code); + } + Ok(_) => (), + } + }; + } + + macro_rules! need_bits { + ($self:expr, $n:expr) => { + match $self.bit_reader.need_bits($n) { + Err(return_code) => { + self.mode = mode; + return $self.inflate_leave(return_code); + } + Ok(v) => v, + } + }; + } + + let ret = 'label: loop { + match mode { Mode::Head => { if self.wrap == 0 { - self.mode = Mode::TypeDo; - + mode = Mode::TypeDo; continue 'label; } @@ -884,7 +892,7 @@ impl State<'_> { self.checksum = crc32(crate::CRC32_INITIAL_VALUE, &[b0, b1]); self.bit_reader.init_bits(); - self.mode = Mode::Flags; + mode = Mode::Flags; continue 'label; } @@ -898,12 +906,12 @@ impl State<'_> { || ((self.bit_reader.bits(8) << 8) + (self.bit_reader.hold() >> 8)) % 31 != 0 { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("incorrect header check\0"); } if self.bit_reader.bits(4) != Z_DEFLATED as u64 { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("unknown compression method\0"); } @@ -915,7 +923,7 @@ impl State<'_> { } if len as i32 > MAX_WBITS || len > self.wbits { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid window size\0"); } @@ -926,13 +934,13 @@ impl State<'_> { if self.bit_reader.hold() & 0x200 != 0 { self.bit_reader.init_bits(); - self.mode = Mode::DictId; + mode = Mode::DictId; continue 'label; } else { self.bit_reader.init_bits(); - self.mode = Mode::Type; + mode = Mode::Type; continue 'label; } @@ -943,12 +951,12 @@ impl State<'_> { // Z_DEFLATED = 8 is the only supported method if self.gzip_flags & 0xff != Z_DEFLATED { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("unknown compression method\0"); } if self.gzip_flags & 0xe000 != 0 { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("unknown header flags set\0"); } @@ -963,7 +971,7 @@ impl State<'_> { } self.bit_reader.init_bits(); - self.mode = Mode::Time; + mode = Mode::Time; continue 'label; } @@ -979,7 +987,7 @@ impl State<'_> { } self.bit_reader.init_bits(); - self.mode = Mode::Os; + mode = Mode::Os; continue 'label; } @@ -996,7 +1004,7 @@ impl State<'_> { } self.bit_reader.init_bits(); - self.mode = Mode::ExLen; + mode = Mode::ExLen; continue 'label; } @@ -1019,7 +1027,7 @@ impl State<'_> { head.extra = core::ptr::null_mut(); } - self.mode = Mode::Extra; + mode = Mode::Extra; continue 'label; } @@ -1082,7 +1090,7 @@ impl State<'_> { } self.length = 0; - self.mode = Mode::Name; + mode = Mode::Name; continue 'label; } @@ -1140,7 +1148,7 @@ impl State<'_> { } self.length = 0; - self.mode = Mode::Comment; + mode = Mode::Comment; continue 'label; } @@ -1197,7 +1205,7 @@ impl State<'_> { head.comment = core::ptr::null_mut(); } - self.mode = Mode::HCrc; + mode = Mode::HCrc; continue 'label; } @@ -1208,7 +1216,7 @@ impl State<'_> { if (self.wrap & 4) != 0 && self.bit_reader.hold() as u32 != (self.checksum & 0xffff) { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("header crc mismatch\0"); } @@ -1226,7 +1234,7 @@ impl State<'_> { self.checksum = crate::CRC32_INITIAL_VALUE; } - self.mode = Mode::Type; + mode = Mode::Type; continue 'label; } @@ -1237,7 +1245,7 @@ impl State<'_> { Block | Trees => break 'label ReturnCode::Ok, NoFlush | SyncFlush | Finish => { // NOTE: this is slightly different to what zlib-rs does! - self.mode = Mode::TypeDo; + mode = Mode::TypeDo; continue 'label; } } @@ -1245,7 +1253,7 @@ impl State<'_> { Mode::TypeDo => { if self.flags.contains(Flags::IS_LAST_BLOCK) { self.bit_reader.next_byte_boundary(); - self.mode = Mode::Check; + mode = Mode::Check; continue 'label; } @@ -1262,7 +1270,7 @@ impl State<'_> { self.bit_reader.drop_bits(2); - self.mode = Mode::Stored; + mode = Mode::Stored; continue 'label; } @@ -1279,7 +1287,7 @@ impl State<'_> { bits: 5, }; - self.mode = Mode::Len_; + mode = Mode::Len_; self.bit_reader.drop_bits(2); @@ -1294,7 +1302,7 @@ impl State<'_> { self.bit_reader.drop_bits(2); - self.mode = Mode::Table; + mode = Mode::Table; continue 'label; } @@ -1303,7 +1311,7 @@ impl State<'_> { self.bit_reader.drop_bits(2); - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid block type\0"); } _ => { @@ -1322,7 +1330,7 @@ impl State<'_> { // eprintln!("hold {hold:#x}"); if hold as u16 != !((hold >> 16) as u16) { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid stored block lengths\0"); } @@ -1334,7 +1342,7 @@ impl State<'_> { if let InflateFlush::Trees = self.flush { break 'label self.inflate_leave(ReturnCode::Ok); } else { - self.mode = Mode::CopyBlock; + mode = Mode::CopyBlock; continue 'label; } @@ -1360,7 +1368,7 @@ impl State<'_> { self.length -= copy; } - self.mode = Mode::Type; + mode = Mode::Type; continue 'label; } @@ -1388,25 +1396,30 @@ impl State<'_> { self.out_available = self.writer.capacity() - self.writer.len(); if self.wrap & 4 != 0 && given_checksum != self.checksum { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("incorrect data check\0"); } self.bit_reader.init_bits(); } - self.mode = Mode::Length; + mode = Mode::Length; continue 'label; } Mode::Len_ => { - self.mode = Mode::Len; + mode = Mode::Len; continue 'label; } - Mode::Len => match self.len_and_friends() { - ControlFlow::Break(return_code) => break 'label return_code, - ControlFlow::Continue(()) => continue 'label, - }, + Mode::Len => { + self.mode = mode; + let val = self.len_and_friends(); + mode = self.mode; + match val { + ControlFlow::Break(return_code) => break 'label return_code, + ControlFlow::Continue(()) => continue 'label, + } + } Mode::LenExt => { // NOTE: this branch must be kept in sync with its counterpart in `len_and_friends` let extra = self.extra; @@ -1422,7 +1435,7 @@ impl State<'_> { // eprintln!("inflate: length {}", state.length); self.was = self.length; - self.mode = Mode::Dist; + mode = Mode::Dist; continue 'label; } @@ -1436,7 +1449,7 @@ impl State<'_> { self.writer.push(self.length as u8); - self.mode = Mode::Len; + mode = Mode::Len; continue 'label; } @@ -1477,14 +1490,14 @@ impl State<'_> { self.bit_reader.drop_bits(here.bits); if here.op & 64 != 0 { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid distance code\0"); } self.offset = here.val as usize; self.extra = (here.op & MAX_BITS) as usize; - self.mode = Mode::DistExt; + mode = Mode::DistExt; continue 'label; } @@ -1500,13 +1513,13 @@ impl State<'_> { } if INFLATE_STRICT && self.offset > self.dmax { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid distance code too far back\0"); } // eprintln!("inflate: distance {}", state.offset); - self.mode = Mode::Match; + mode = Mode::Match; continue 'label; } @@ -1533,7 +1546,7 @@ impl State<'_> { if copy > self.window.have() { if self.flags.contains(Flags::SANE) { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid distance too far back\0"); } @@ -1568,7 +1581,7 @@ impl State<'_> { self.length -= copy; if self.length == 0 { - self.mode = Mode::Len; + mode = Mode::Len; continue 'label; } else { @@ -1589,12 +1602,12 @@ impl State<'_> { // TODO pkzit_bug_workaround if self.nlen > 286 || self.ndist > 30 { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("too many length or distance symbols\0"); } self.have = 0; - self.mode = Mode::LenLens; + mode = Mode::LenLens; continue 'label; } @@ -1626,7 +1639,7 @@ impl State<'_> { self.len_table.bits, &mut self.work, ) else { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid code lengths set\0"); }; @@ -1634,7 +1647,7 @@ impl State<'_> { self.len_table.bits = root; self.have = 0; - self.mode = Mode::CodeLens; + mode = Mode::CodeLens; continue 'label; } @@ -1662,7 +1675,7 @@ impl State<'_> { need_bits!(self, here_bits as usize + 2); self.bit_reader.drop_bits(here_bits); if self.have == 0 { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid bit length repeat\0"); } @@ -1671,7 +1684,7 @@ impl State<'_> { self.bit_reader.drop_bits(2); if self.have + copy > self.nlen + self.ndist { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid bit length repeat\0"); } @@ -1688,7 +1701,7 @@ impl State<'_> { self.bit_reader.drop_bits(3); if self.have + copy > self.nlen + self.ndist { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid bit length repeat\0"); } @@ -1705,7 +1718,7 @@ impl State<'_> { self.bit_reader.drop_bits(7); if self.have + copy > self.nlen + self.ndist { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid bit length repeat\0"); } @@ -1719,7 +1732,7 @@ impl State<'_> { // check for end-of-block code (better have one) if self.lens[256] == 0 { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid code -- missing end-of-block\0"); } @@ -1735,7 +1748,7 @@ impl State<'_> { self.len_table.bits, &mut self.work, ) else { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid literal/lengths set\0"); }; @@ -1752,14 +1765,14 @@ impl State<'_> { self.dist_table.bits, &mut self.work, ) else { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("invalid distances set\0"); }; self.dist_table.bits = root; self.dist_table.codes = Codes::Dist; - self.mode = Mode::Len_; + mode = Mode::Len_; if matches!(self.flush, InflateFlush::Trees) { break 'label self.inflate_leave(ReturnCode::Ok); @@ -1774,7 +1787,7 @@ impl State<'_> { self.checksum = crate::ADLER32_INITIAL_VALUE as _; - self.mode = Mode::Type; + mode = Mode::Type; continue 'label; } @@ -1785,7 +1798,7 @@ impl State<'_> { self.bit_reader.init_bits(); - self.mode = Mode::Dict; + mode = Mode::Dict; continue 'label; } @@ -1808,7 +1821,7 @@ impl State<'_> { if self.wrap != 0 && self.gzip_flags != 0 { need_bits!(self, 32); if (self.wrap & 4) != 0 && self.bit_reader.hold() != self.total as u64 { - self.mode = Mode::Bad; + mode = Mode::Bad; break 'label self.bad("incorrect length check\0"); } @@ -1819,7 +1832,11 @@ impl State<'_> { break 'label ReturnCode::StreamEnd; } }; - } + }; + + self.mode = mode; + + ret } fn bad(&mut self, msg: &'static str) -> ReturnCode {