Skip to content

Commit 9e9c7ab

Browse files
Engine::internal_decode now returns DecodeSliceError
Implementations must now precisely, not conservatively, return an error when the output length is too small.
1 parent a8a60f4 commit 9e9c7ab

File tree

9 files changed

+237
-159
lines changed

9 files changed

+237
-159
lines changed

benches/benchmarks.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,8 @@ fn do_encode_bench_slice(b: &mut Bencher, &size: &usize) {
102102
fn do_encode_bench_stream(b: &mut Bencher, &size: &usize) {
103103
let mut v: Vec<u8> = Vec::with_capacity(size);
104104
fill(&mut v);
105-
let mut buf = Vec::new();
105+
let mut buf = Vec::with_capacity(size * 2);
106106

107-
buf.reserve(size * 2);
108107
b.iter(|| {
109108
buf.clear();
110109
let mut stream_enc = write::EncoderWriter::new(&mut buf, &STANDARD);

src/decode.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ impl error::Error for DecodeError {}
5252
pub enum DecodeSliceError {
5353
/// A [DecodeError] occurred
5454
DecodeError(DecodeError),
55-
/// The provided slice _may_ be too small.
56-
///
57-
/// The check is conservative (assumes the last triplet of output bytes will all be needed).
55+
/// The provided slice is too small.
5856
OutputSliceTooSmall,
5957
}
6058

src/engine/general_purpose/decode.rs

+51-24
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
use crate::{
22
engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
3-
DecodeError, PAD_BYTE,
3+
DecodeError, DecodeSliceError, PAD_BYTE,
44
};
55

66
#[doc(hidden)]
77
pub struct GeneralPurposeEstimate {
8+
/// input len % 4
89
rem: usize,
9-
conservative_len: usize,
10+
conservative_decoded_len: usize,
1011
}
1112

1213
impl GeneralPurposeEstimate {
1314
pub(crate) fn new(encoded_len: usize) -> Self {
1415
let rem = encoded_len % 4;
1516
Self {
1617
rem,
17-
conservative_len: (encoded_len / 4 + (rem > 0) as usize) * 3,
18+
conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3,
1819
}
1920
}
2021
}
2122

2223
impl DecodeEstimate for GeneralPurposeEstimate {
2324
fn decoded_len_estimate(&self) -> usize {
24-
self.conservative_len
25+
self.conservative_decoded_len
2526
}
2627
}
2728

@@ -38,25 +39,9 @@ pub(crate) fn decode_helper(
3839
decode_table: &[u8; 256],
3940
decode_allow_trailing_bits: bool,
4041
padding_mode: DecodePaddingMode,
41-
) -> Result<DecodeMetadata, DecodeError> {
42-
// detect a trailing invalid byte, like a newline, as a user convenience
43-
if estimate.rem == 1 {
44-
let last_byte = input[input.len() - 1];
45-
// exclude pad bytes; might be part of padding that extends from earlier in the input
46-
if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
47-
return Err(DecodeError::InvalidByte(input.len() - 1, last_byte));
48-
}
49-
}
50-
51-
// skip last quad, even if it's complete, as it may have padding
52-
let input_complete_nonterminal_quads_len = input
53-
.len()
54-
.saturating_sub(estimate.rem)
55-
// if rem was 0, subtract 4 to avoid padding
56-
.saturating_sub((estimate.rem == 0) as usize * 4);
57-
debug_assert!(
58-
input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
59-
);
42+
) -> Result<DecodeMetadata, DecodeSliceError> {
43+
let input_complete_nonterminal_quads_len =
44+
complete_quads_len(input, estimate.rem, output.len(), decode_table)?;
6045

6146
const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
6247
const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;
@@ -135,6 +120,48 @@ pub(crate) fn decode_helper(
135120
)
136121
}
137122

123+
/// Returns the length in bytes of the input's leading complete quads, excluding the final quad even when that quad is itself complete.
124+
///
125+
/// Returns an error if the output len is not big enough for decoding those complete quads, or if
126+
/// the input len % 4 == 1 and that last byte is an invalid value other than a pad byte.
127+
///
128+
/// - `input` is the base64 input
129+
/// - `input_len_rem` is input len % 4
130+
/// - `output_len` is the length of the output slice
131+
pub(crate) fn complete_quads_len(
132+
input: &[u8],
133+
input_len_rem: usize,
134+
output_len: usize,
135+
decode_table: &[u8; 256],
136+
) -> Result<usize, DecodeSliceError> {
137+
debug_assert!(input.len() % 4 == input_len_rem);
138+
139+
// detect a trailing invalid byte, like a newline, as a user convenience
140+
if input_len_rem == 1 {
141+
let last_byte = input[input.len() - 1];
142+
// exclude pad bytes; might be part of padding that extends from earlier in the input
143+
if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
144+
return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into());
145+
}
146+
};
147+
148+
// skip last quad, even if it's complete, as it may have padding
149+
let input_complete_nonterminal_quads_len = input
150+
.len()
151+
.saturating_sub(input_len_rem)
152+
// if rem was 0, subtract 4 to avoid padding
153+
.saturating_sub((input_len_rem == 0) as usize * 4);
154+
debug_assert!(
155+
input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
156+
);
157+
158+
// check that everything except the last quad (which is handled by decode_suffix) will fit
159+
if output_len < input_complete_nonterminal_quads_len / 4 * 3 {
160+
return Err(DecodeSliceError::OutputSliceTooSmall);
161+
};
162+
Ok(input_complete_nonterminal_quads_len)
163+
}
164+
138165
/// Decode 8 bytes of input into 6 bytes of output.
139166
///
140167
/// `input` is the 8 bytes to decode.
@@ -321,7 +348,7 @@ mod tests {
321348
let len_128 = encoded_len as u128;
322349

323350
let estimate = GeneralPurposeEstimate::new(encoded_len);
324-
assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_len as u128);
351+
assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_decoded_len as u128);
325352
})
326353
}
327354
}

src/engine/general_purpose/decode_suffix.rs

+25-34
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode},
3-
DecodeError, PAD_BYTE,
3+
DecodeError, DecodeSliceError, PAD_BYTE,
44
};
55

66
/// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided
@@ -16,11 +16,11 @@ pub(crate) fn decode_suffix(
1616
decode_table: &[u8; 256],
1717
decode_allow_trailing_bits: bool,
1818
padding_mode: DecodePaddingMode,
19-
) -> Result<DecodeMetadata, DecodeError> {
19+
) -> Result<DecodeMetadata, DecodeSliceError> {
2020
debug_assert!((input.len() - input_index) <= 4);
2121

22-
// Decode any leftovers that might not be a complete input chunk of 8 bytes.
23-
// Use a u64 as a stack-resident 8 byte buffer.
22+
// Decode any leftovers that might not be a complete input chunk of 4 bytes.
23+
// Use a u32 as a stack-resident 4 byte buffer.
2424
let mut morsels_in_leftover = 0;
2525
let mut padding_bytes_count = 0;
2626
// offset from input_index
@@ -44,22 +44,14 @@ pub(crate) fn decode_suffix(
4444
// may be treated as an error condition.
4545

4646
if leftover_index < 2 {
47-
// Check for case #2.
48-
let bad_padding_index = input_index
49-
+ if padding_bytes_count > 0 {
50-
// If we've already seen padding, report the first padding index.
51-
// This is to be consistent with the normal decode logic: it will report an
52-
// error on the first padding character (since it doesn't expect to see
53-
// anything but actual encoded data).
54-
// This could only happen if the padding started in the previous quad since
55-
// otherwise this case would have been hit at i == 4 if it was the same
56-
// quad.
57-
first_padding_offset
58-
} else {
59-
// haven't seen padding before, just use where we are now
60-
leftover_index
61-
};
62-
return Err(DecodeError::InvalidByte(bad_padding_index, b));
47+
// Check for error #2.
48+
// Either the previous byte was padding, in which case we would have already hit
49+
// this case, or it wasn't, in which case this is the first such error.
50+
debug_assert!(
51+
leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0)
52+
);
53+
let bad_padding_index = input_index + leftover_index;
54+
return Err(DecodeError::InvalidByte(bad_padding_index, b).into());
6355
}
6456

6557
if padding_bytes_count == 0 {
@@ -75,10 +67,9 @@ pub(crate) fn decode_suffix(
7567
// non-suffix '=' in trailing chunk either. Report error as first
7668
// erroneous padding.
7769
if padding_bytes_count > 0 {
78-
return Err(DecodeError::InvalidByte(
79-
input_index + first_padding_offset,
80-
PAD_BYTE,
81-
));
70+
return Err(
71+
DecodeError::InvalidByte(input_index + first_padding_offset, PAD_BYTE).into(),
72+
);
8273
}
8374

8475
last_symbol = b;
@@ -87,7 +78,7 @@ pub(crate) fn decode_suffix(
8778
// Pack the leftovers from left to right.
8879
let morsel = decode_table[b as usize];
8980
if morsel == INVALID_VALUE {
90-
return Err(DecodeError::InvalidByte(input_index + leftover_index, b));
81+
return Err(DecodeError::InvalidByte(input_index + leftover_index, b).into());
9182
}
9283

9384
morsels[morsels_in_leftover] = morsel;
@@ -97,24 +88,22 @@ pub(crate) fn decode_suffix(
9788
// If there was 1 trailing byte, and it was valid, and we got to this point without hitting
9889
// an invalid byte, now we can report invalid length
9990
if !input.is_empty() && morsels_in_leftover < 2 {
100-
return Err(DecodeError::InvalidLength(
101-
input_index + morsels_in_leftover,
102-
));
91+
return Err(DecodeError::InvalidLength(input_index + morsels_in_leftover).into());
10392
}
10493

10594
match padding_mode {
10695
DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ }
10796
DecodePaddingMode::RequireCanonical => {
10897
// allow empty input
10998
if (padding_bytes_count + morsels_in_leftover) % 4 != 0 {
110-
return Err(DecodeError::InvalidPadding);
99+
return Err(DecodeError::InvalidPadding.into());
111100
}
112101
}
113102
DecodePaddingMode::RequireNone => {
114103
if padding_bytes_count > 0 {
115104
// check at the end to make sure we let the cases of padding that should be InvalidByte
116105
// get hit
117-
return Err(DecodeError::InvalidPadding);
106+
return Err(DecodeError::InvalidPadding.into());
118107
}
119108
}
120109
}
@@ -127,7 +116,7 @@ pub(crate) fn decode_suffix(
127116
// bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a
128117
// mask based on how many bits are used for just the canonical encoding, and optionally
129118
// error if any other bits are set. In the example of one encoded byte -> 2 symbols,
130-
// 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and
119+
// 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and
131120
// useless since there are no more symbols to provide the necessary 4 additional bits
132121
// to finish the second original byte.
133122

@@ -147,16 +136,18 @@ pub(crate) fn decode_suffix(
147136
return Err(DecodeError::InvalidLastSymbol(
148137
input_index + morsels_in_leftover - 1,
149138
last_symbol,
150-
));
139+
)
140+
.into());
151141
}
152142

153143
// Strangely, this approach benchmarks better than writing bytes one at a time,
154144
// or copy_from_slice into output.
155145
for _ in 0..leftover_bytes_to_append {
156146
let hi_byte = (leftover_num >> 24) as u8;
157147
leftover_num <<= 8;
158-
// TODO use checked writes
159-
output[output_index] = hi_byte;
148+
*output
149+
.get_mut(output_index)
150+
.ok_or(DecodeSliceError::OutputSliceTooSmall)? = hi_byte;
160151
output_index += 1;
161152
}
162153

src/engine/general_purpose/mod.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ use crate::{
33
alphabet,
44
alphabet::Alphabet,
55
engine::{Config, DecodeMetadata, DecodePaddingMode},
6-
DecodeError,
6+
DecodeSliceError,
77
};
88
use core::convert::TryInto;
99

10-
mod decode;
10+
pub(crate) mod decode;
1111
pub(crate) mod decode_suffix;
1212

1313
pub use decode::GeneralPurposeEstimate;
@@ -173,7 +173,7 @@ impl super::Engine for GeneralPurpose {
173173
input: &[u8],
174174
output: &mut [u8],
175175
estimate: Self::DecodeEstimate,
176-
) -> Result<DecodeMetadata, DecodeError> {
176+
) -> Result<DecodeMetadata, DecodeSliceError> {
177177
decode::decode_helper(
178178
input,
179179
estimate,

src/engine/mod.rs

+26-15
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,13 @@ pub trait Engine: Send + Sync {
8383
///
8484
/// Non-canonical trailing bits in the final tokens or non-canonical padding must be reported as
8585
/// errors unless the engine is configured otherwise.
86-
///
87-
/// # Panics
88-
///
89-
/// Panics if `output` is too small.
9086
#[doc(hidden)]
9187
fn internal_decode(
9288
&self,
9389
input: &[u8],
9490
output: &mut [u8],
9591
decode_estimate: Self::DecodeEstimate,
96-
) -> Result<DecodeMetadata, DecodeError>;
92+
) -> Result<DecodeMetadata, DecodeSliceError>;
9793

9894
/// Returns the config for this engine.
9995
fn config(&self) -> &Self::Config;
@@ -253,7 +249,13 @@ pub trait Engine: Send + Sync {
253249
let mut buffer = vec![0; estimate.decoded_len_estimate()];
254250

255251
let bytes_written = engine
256-
.internal_decode(input_bytes, &mut buffer, estimate)?
252+
.internal_decode(input_bytes, &mut buffer, estimate)
253+
.map_err(|e| match e {
254+
DecodeSliceError::DecodeError(e) => e,
255+
DecodeSliceError::OutputSliceTooSmall => {
256+
unreachable!("Vec is sized conservatively")
257+
}
258+
})?
257259
.decoded_len;
258260

259261
buffer.truncate(bytes_written);
@@ -318,7 +320,13 @@ pub trait Engine: Send + Sync {
318320
let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..];
319321

320322
let bytes_written = engine
321-
.internal_decode(input_bytes, buffer_slice, estimate)?
323+
.internal_decode(input_bytes, buffer_slice, estimate)
324+
.map_err(|e| match e {
325+
DecodeSliceError::DecodeError(e) => e,
326+
DecodeSliceError::OutputSliceTooSmall => {
327+
unreachable!("Vec is sized conservatively")
328+
}
329+
})?
322330
.decoded_len;
323331

324332
buffer.truncate(starting_output_len + bytes_written);
@@ -354,15 +362,12 @@ pub trait Engine: Send + Sync {
354362
where
355363
E: Engine + ?Sized,
356364
{
357-
let estimate = engine.internal_decoded_len_estimate(input_bytes.len());
358-
359-
if output.len() < estimate.decoded_len_estimate() {
360-
return Err(DecodeSliceError::OutputSliceTooSmall);
361-
}
362-
363365
engine
364-
.internal_decode(input_bytes, output, estimate)
365-
.map_err(|e| e.into())
366+
.internal_decode(
367+
input_bytes,
368+
output,
369+
engine.internal_decoded_len_estimate(input_bytes.len()),
370+
)
366371
.map(|dm| dm.decoded_len)
367372
}
368373

@@ -400,6 +405,12 @@ pub trait Engine: Send + Sync {
400405
engine.internal_decoded_len_estimate(input_bytes.len()),
401406
)
402407
.map(|dm| dm.decoded_len)
408+
.map_err(|e| match e {
409+
DecodeSliceError::DecodeError(e) => e,
410+
DecodeSliceError::OutputSliceTooSmall => {
411+
panic!("Output slice is too small")
412+
}
413+
})
403414
}
404415

405416
inner(self, input.as_ref(), output)

0 commit comments

Comments
 (0)