diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6936990 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +**/*.rs.bk +Cargo.lock diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..16e39f6 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,19 @@ +branches: + only: + # This is where pull requests from "bors r+" are built. + - staging + # This is where pull requests from "bors try" are built. + - trying + # Not really necessary, just to get a green badge on “master” + - master +language: rust + +jobs: + include: + - rust: stable + - rust: beta + - rust: nightly + env: FEATURES=--features=nightly +script: + - cargo test --no-default-features + - cargo test $FEATURES diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..21304f0 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "b64-ct" +version = "0.1.0" +authors = ["Fortanix, Inc."] +license = "MPL-2.0" +edition = "2018" +description = """ +Fast and secure Base64 encoding/decoding. + +This crate provides an implementation of Base64 encoding/decoding that is +designed to be resistant against software side-channel attacks (such as timing +& cache attacks), see the documentation for details. On certain platforms it +also uses SIMD making it very fast. This makes it suitable for e.g. decoding +cryptographic private keys in PEM format. + +The API is very similar to the base64 implementation in the old rustc-serialize +crate, making it easy to use in existing projects. +""" +repository = "https://github.com/fortanix/b64-ct/" +keywords = ["base64", "constant-time"] +categories = ["cryptography", "encoding", "no-std"] +readme = "README.md" + +[features] +default = ["std"] +std = [] +nightly = [] # Used only for testing + +[dev-dependencies] +rand = "0.6" +paste = "0.1" diff --git a/LICENSE b/LICENSE index 14e2f77..7ecdf26 100644 --- a/LICENSE +++ b/LICENSE @@ -1,3 +1,65 @@ +This project as a whole is licensed under the terms of the Mozilla Public +License, Version 2.0, see below. This project includes code written by The Rust +Project Developers and Wojciech Muła, license terms below. + +---------------------------------------------------------------------------- + +Copyright (c) 2014 The Rust Project Developers + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + +---------------------------------------------------------------------------- + +Copyright (c) 2015-2018, Wojciech Muła +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +---------------------------------------------------------------------------- + Mozilla Public License Version 2.0 ================================== diff --git a/README.md b/README.md index 7d0ab40..7a476aa 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,32 @@ +# Fast and secure Base64 encoding/decoding + +This crate provides an implementation of Base64 encoding/decoding that is +designed to be resistant against software side-channel attacks (such as timing +& cache attacks), see the [documentation] for details. On certain platforms it +also uses SIMD making it very fast. This makes it suitable for e.g. decoding +cryptographic private keys in PEM format. + +The API is very similar to the base64 implementation in the old rustc-serialize +crate, making it easy to use in existing projects. + +[documentation]: https://docs.rs/b64-ct + +# Implementation + +Depending on the runtime CPU architecture, this crate uses different +implementations with different security properties. + +* x86 with AVX2: All lookup tables are implemented with SIMD + instructions. No secret-dependent memory accceses. +* Other platforms: Lookups are limited to 64-byte aligned lookup tables. On + platforms with 64-byte cache lines this may be sufficient to prevent + certain cache side-channel attacks. However, it's known that this is [not + sufficient for all platforms]. + +We graciously welcome contributed support for other platforms! + +[not sufficient on some platforms]: https://ts.data61.csiro.au/projects/TS/cachebleed/ + # Contributing We gratefully accept bug reports and contributions from the community. diff --git a/bors.toml b/bors.toml new file mode 100644 index 0000000..ca08e81 --- /dev/null +++ b/bors.toml @@ -0,0 +1,3 @@ +status = [ + "continuous-integration/travis-ci/push", +] diff --git a/src/avx2.rs b/src/avx2.rs new file mode 100644 index 0000000..3c487e7 --- /dev/null +++ b/src/avx2.rs @@ -0,0 +1,39 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use core::arch::x86_64::*; + +#[rustfmt::skip] +pub(crate) unsafe fn dup_mm_setr_epi8(e: [i8; 16]) -> __m256i { + _mm256_setr_epi8( + e[0x0], e[0x1], e[0x2], e[0x3], e[0x4], e[0x5], e[0x6], e[0x7], + e[0x8], e[0x9], e[0xa], e[0xb], e[0xc], e[0xd], e[0xe], e[0xf], + e[0x0], e[0x1], e[0x2], e[0x3], e[0x4], e[0x5], e[0x6], e[0x7], + e[0x8], e[0x9], e[0xa], e[0xb], e[0xc], e[0xd], e[0xe], e[0xf], + ) +} + +#[rustfmt::skip] +pub(crate) unsafe fn dup_mm_setr_epu8(e: [u8; 16]) -> __m256i { + _mm256_setr_epi8( + e[0x0] as _, e[0x1] as _, e[0x2] as _, e[0x3] as _, e[0x4] as _, e[0x5] as _, e[0x6] as _, e[0x7] as _, + e[0x8] as _, e[0x9] as _, e[0xa] as _, e[0xb] as _, e[0xc] as _, e[0xd] as _, e[0xe] as _, e[0xf] as _, + e[0x0] as _, e[0x1] as _, e[0x2] as _, e[0x3] as _, e[0x4] as _, e[0x5] as _, e[0x6] as _, e[0x7] as _, + e[0x8] as _, e[0x9] as _, e[0xa] as _, e[0xb] as _, e[0xc] as _, e[0xd] as _, e[0xe] as _, e[0xf] as _, + ) +} + +pub(crate) unsafe fn _mm256_not_si256(i: __m256i) -> __m256i { + _mm256_xor_si256(i, _mm256_set1_epi8(!0)) +} + +pub(crate) unsafe fn array_as_m256i(v: [u8; 32]) -> __m256i { + core::mem::transmute(v) +} + +pub(crate) unsafe fn m256i_as_array(v: __m256i) -> [u8; 32] { + core::mem::transmute(v) +} diff --git a/src/decode/avx2.rs b/src/decode/avx2.rs new file mode 100644 index 0000000..0071b14 --- /dev/null +++ b/src/decode/avx2.rs @@ -0,0 +1,241 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use core::arch::x86_64::*; + +use crate::avx2::*; + +/// # Safety +/// The caller should ensure the requisite CPU features are enabled. +#[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")] +unsafe fn decode_avx2(input: __m256i) -> (__m256i, u32, u32) { + // Step 0. Split input bytes into nibbles. + let higher_nibble = _mm256_and_si256(_mm256_srli_epi16(input, 4), _mm256_set1_epi8(0x0f)); + let lower_nibble = _mm256_and_si256(input, _mm256_set1_epi8(0x0f)); + + // Step 1. Find invalid characters. Steps 2 & 3 will compute invalid 6-bit + // values for invalid characters. The result of the computation should only + // be used if no invalid characters are found. + + // This table contains 128 bits, one bit for each of the lower 128 ASCII + // characters. A set bit indicates that the character is in the base64 + // character set (the character is valid) or the character is considered + // ASCII whitespace. This table is indexed by ASCII low nibble. + #[rustfmt::skip] + let row_lut = dup_mm_setr_epu8([ + 0b1010_1100, 0b1111_1000, 0b1111_1000, 0b1111_1000, + 0b1111_1000, 0b1111_1000, 0b1111_1000, 0b1111_1000, + 0b1111_1000, 0b1111_1001, 0b1111_0001, 0b0101_0100, + 0b0101_0001, 0b0101_0101, 0b0101_0000, 0b0111_0100, + ]); + + // This table contains column offsets (within a byte) for the table above. + // This table is indexed by ASCII high nibble. + #[rustfmt::skip] + let column_lut = dup_mm_setr_epu8([ + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0, 0, 0, 0, 0, 0, 0, 0, + ]); + + // Lookup table row + let row = _mm256_shuffle_epi8(row_lut, lower_nibble); + // Lookup column offset + let column = _mm256_shuffle_epi8(column_lut, higher_nibble); + // Lookup valid characters + let valid = _mm256_and_si256(row, column); + // Compute invalid character mask + let non_match = _mm256_cmpeq_epi8(valid, _mm256_setzero_si256()); + // Transform mask to u32 + let invalid_mask = _mm256_movemask_epi8(non_match); + + // Step 2. Numbers & letters: compute 6-bit value for the 3 different + // ranges by simply adjusting the ASCII value. + + // This table contains the offsets for the alphanumerical ASCII ranges. + // This table is indexed by ASCII high nibble. + #[rustfmt::skip] + let shift_lut = dup_mm_setr_epi8([ + 0, 0, 0, + // '0' through '9' + 4, + // 'A' through 'Z' + -65, -65, + // 'a' through 'z' + -71, -71, + 0, 0, 0, 0, 0, 0, 0, 0, + ]); + + // Get offset + let shift = _mm256_shuffle_epi8(shift_lut, higher_nibble); + // Compute 6-bit value + let shifted = _mm256_add_epi8(input, shift); + + // Step 3. Special characters: lookup 6-bit value by looking it up in a + // table. + + // This table specifies the ASCII ranges that contain valid special + // characters. This table is indexed by ASCII high nibble. + #[rustfmt::skip] + let spcrange_lut = dup_mm_setr_epu8([ + 0, 0, 0xff, 0, 0, 0xff, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]); + + // This table specifies the (inverted) 6-bit values for the special + // characters. The values in this table act as both a value and a blend + // mask. This table is indexed by the difference between ASCII low and high + // nibble. + #[rustfmt::skip] + let spcchar_lut = dup_mm_setr_epu8([ + 0, 0, 0, 0, 0, 0, 0, 0, + // '+', '_', '-', '/' + 0, !62, !63, !62, 0, !63, 0, 0, + ]); + + // Check if character is in the range for special characters + let sel_range = _mm256_shuffle_epi8(spcrange_lut, higher_nibble); + // Compute difference between ASCII low and high nibble + let lo_sub_hi = _mm256_sub_epi8(lower_nibble, higher_nibble); + // Lookup special character 6-bit value + let specials = _mm256_shuffle_epi8(spcchar_lut, lo_sub_hi); + // Combine blend masks from range and value + let sel_spec = _mm256_and_si256(sel_range, specials); + + // Combine results of step 1 and step 2 + let result = _mm256_blendv_epi8(shifted, _mm256_not_si256(specials), sel_spec); + + // Step 4. Compute mask for valid non-whitespace bytes. The mask will be + // used to copy only relevant bytes into the output. + + // This table specifies the character ranges which should be decoded. The + // format is a range table for the PCMPESTRM instruction. + #[rustfmt::skip] + let valid_nonws_set = _mm_setr_epi8( + b'A' as _, b'Z' as _, + b'a' as _, b'z' as _, + b'0' as _, b'9' as _, + b'+' as _, b'+' as _, + b'/' as _, b'/' as _, + b'-' as _, b'-' as _, + b'_' as _, b'_' as _, + 0, 0, + ); + + // Split input into 128-bit values + let lane0 = _mm256_extracti128_si256(input, 0); + let lane1 = _mm256_extracti128_si256(input, 1); + // Compute bitmask for each 128-bit value + const CMP_FLAGS: i32 = _SIDD_UBYTE_OPS | _SIDD_CMP_RANGES | _SIDD_BIT_MASK; + let mask0 = _mm_cmpestrm(valid_nonws_set, 14, lane0, 16, CMP_FLAGS); + let mask1 = _mm_cmpestrm(valid_nonws_set, 14, lane1, 16, CMP_FLAGS); + // Combine bitmasks into integer value + let valid_mask = + _mm_extract_epi16(mask0, 0) as u32 | ((_mm_extract_epi16(mask1, 0) as u32) << 16); + + (result, invalid_mask as _, valid_mask as _) +} + +/// # Safety +/// The caller should ensure the requisite CPU features are enabled. +#[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")] +unsafe fn decode_block(block: &mut ::Block) -> super::BlockResult { + let input = array_as_m256i(*block); + + let (unpacked, invalid_mask, mut valid_mask) = decode_avx2(input); + + let unpacked = m256i_as_array(unpacked); + + let first_invalid = match invalid_mask.trailing_zeros() { + 32 => None, + v => Some(v as _), + }; + let out_length = valid_mask.count_ones() as _; + + let mut out_iter = block.iter_mut(); + // TODO: Optimize loop (https://github.com/fortanix/b64-ct/issues/2) + for &val in unpacked.iter() { + if (valid_mask & 1) == 1 { + *out_iter.next().unwrap() = val; + } + valid_mask >>= 1; + } + + super::BlockResult { + out_length, + first_invalid, + } +} + +/// # Safety +/// The caller should ensure the requisite CPU features are enabled. +#[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")] +unsafe fn pack_block(input: &::Input, output: &mut [u8]) { + assert_eq!(output.len(), ::OUT_BUF_LEN); + + let unpacked = array_as_m256i(*input); + + // Pack 32× 6-bit values into 16× 12-bit values + let packed1 = _mm256_maddubs_epi16(unpacked, _mm256_set1_epi16(0x0140)); + // Pack 16× 12-bit values into 8× 3-byte values + let packed2 = _mm256_madd_epi16(packed1, _mm256_set1_epi32(0x00011000)); + // Pack 8× 3-byte values into 2× 12-byte values + #[rustfmt::skip] + let packed3 = _mm256_shuffle_epi8(packed2, dup_mm_setr_epu8([ + 2, 1, 0, + 6, 5, 4, + 10, 9, 8, + 14, 13, 12, + 0xff, 0xff, 0xff, 0xff, + ])); + + _mm_storeu_si128( + output.as_mut_ptr() as _, + _mm256_extracti128_si256(packed3, 0), + ); + _mm_storeu_si128( + output.as_mut_ptr().offset(12) as _, + _mm256_extracti128_si256(packed3, 1), + ); +} + +#[derive(Copy, Clone)] +pub(super) struct Avx2 { + _private: (), +} + +impl Avx2 { + /// # Safety + /// The caller should ensure the requisite CPU features are enabled. + #[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")] + pub(super) unsafe fn new() -> Avx2 { + Avx2 { _private: () } + } +} + +impl super::Decoder for Avx2 { + type Block = [u8; 32]; + + #[inline] + fn decode_block(self, block: &mut Self::Block) -> super::BlockResult { + // safe: `self` was given as a witness that the features are available + unsafe { decode_block(block) } + } + + #[inline(always)] + fn zero_block() -> Self::Block { + [b' '; 32] + } +} + +impl super::Packer for Avx2 { + type Input = [u8; 32]; + const OUT_BUF_LEN: usize = 28; + + fn pack_block(self, input: &Self::Input, output: &mut [u8]) { + // safe: `self` was given as a witness that the features are available + unsafe { pack_block(input, output) } + } +} diff --git a/src/decode/lut_align64.rs b/src/decode/lut_align64.rs new file mode 100644 index 0000000..e8409e5 --- /dev/null +++ b/src/decode/lut_align64.rs @@ -0,0 +1,191 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +const INVALID_VALUE: u8 = 0x80; +const SPACE_VALUE: u8 = 0x40; + +use crate::lut_align64::CacheLineLut; + +static LUT1: CacheLineLut = CacheLineLut([ + INVALID_VALUE, // input 0 (0x0) + INVALID_VALUE, // input 1 (0x1) + INVALID_VALUE, // input 2 (0x2) + INVALID_VALUE, // input 3 (0x3) + INVALID_VALUE, // input 4 (0x4) + INVALID_VALUE, // input 5 (0x5) + INVALID_VALUE, // input 6 (0x6) + INVALID_VALUE, // input 7 (0x7) + INVALID_VALUE, // input 8 (0x8) + SPACE_VALUE, // input 9 (0x9) + SPACE_VALUE, // input 10 (0xA) + INVALID_VALUE, // input 11 (0xB) + SPACE_VALUE, // input 12 (0xC) + SPACE_VALUE, // input 13 (0xD) + INVALID_VALUE, // input 14 (0xE) + INVALID_VALUE, // input 15 (0xF) + INVALID_VALUE, // input 16 (0x10) + INVALID_VALUE, // input 17 (0x11) + INVALID_VALUE, // input 18 (0x12) + INVALID_VALUE, // input 19 (0x13) + INVALID_VALUE, // input 20 (0x14) + INVALID_VALUE, // input 21 (0x15) + INVALID_VALUE, // input 22 (0x16) + INVALID_VALUE, // input 23 (0x17) + INVALID_VALUE, // input 24 (0x18) + INVALID_VALUE, // input 25 (0x19) + INVALID_VALUE, // input 26 (0x1A) + INVALID_VALUE, // input 27 (0x1B) + INVALID_VALUE, // input 28 (0x1C) + INVALID_VALUE, // input 29 (0x1D) + INVALID_VALUE, // input 30 (0x1E) + INVALID_VALUE, // input 31 (0x1F) + SPACE_VALUE, // input 32 (0x20) + INVALID_VALUE, // input 33 (0x21) + INVALID_VALUE, // input 34 (0x22) + INVALID_VALUE, // input 35 (0x23) + INVALID_VALUE, // input 36 (0x24) + INVALID_VALUE, // input 37 (0x25) + INVALID_VALUE, // input 38 (0x26) + INVALID_VALUE, // input 39 (0x27) + INVALID_VALUE, // input 40 (0x28) + INVALID_VALUE, // input 41 (0x29) + INVALID_VALUE, // input 42 (0x2A) + 62, // input 43 (0x2B char '+') => 62 (0x3E) + INVALID_VALUE, // input 44 (0x2C) + 62, // input 45 (0x2D char '-') => 62 (0x3E) + INVALID_VALUE, // input 46 (0x2E) + 63, // input 47 (0x2F char '/') => 63 (0x3F) + 52, // input 48 (0x30 char '0') => 52 (0x34) + 53, // input 49 (0x31 char '1') => 53 (0x35) + 54, // input 50 (0x32 char '2') => 54 (0x36) + 55, // input 51 (0x33 char '3') => 55 (0x37) + 56, // input 52 (0x34 char '4') => 56 (0x38) + 57, // input 53 (0x35 char '5') => 57 (0x39) + 58, // input 54 (0x36 char '6') => 58 (0x3A) + 59, // input 55 (0x37 char '7') => 59 (0x3B) + 60, // input 56 (0x38 char '8') => 60 (0x3C) + 61, // input 57 (0x39 char '9') => 61 (0x3D) + INVALID_VALUE, // input 58 (0x3A) + INVALID_VALUE, // input 59 (0x3B) + INVALID_VALUE, // input 60 (0x3C) + INVALID_VALUE, // input 61 (0x3D) + INVALID_VALUE, // input 62 (0x3E) + INVALID_VALUE, // input 63 (0x3F) +]); + +static LUT2: CacheLineLut = CacheLineLut([ + INVALID_VALUE, // input 64 (0x40) + 0, // input 65 (0x41 char 'A') => 0 (0x0) + 1, // input 66 (0x42 char 'B') => 1 (0x1) + 2, // input 67 (0x43 char 'C') => 2 (0x2) + 3, // input 68 (0x44 char 'D') => 3 (0x3) + 4, // input 69 (0x45 char 'E') => 4 (0x4) + 5, // input 70 (0x46 char 'F') => 5 (0x5) + 6, // input 71 (0x47 char 'G') => 6 (0x6) + 7, // input 72 (0x48 char 'H') => 7 (0x7) + 8, // input 73 (0x49 char 'I') => 8 (0x8) + 9, // input 74 (0x4A char 'J') => 9 (0x9) + 10, // input 75 (0x4B char 'K') => 10 (0xA) + 11, // input 76 (0x4C char 'L') => 11 (0xB) + 12, // input 77 (0x4D char 'M') => 12 (0xC) + 13, // input 78 (0x4E char 'N') => 13 (0xD) + 14, // input 79 (0x4F char 'O') => 14 (0xE) + 15, // input 80 (0x50 char 'P') => 15 (0xF) + 16, // input 81 (0x51 char 'Q') => 16 (0x10) + 17, // input 82 (0x52 char 'R') => 17 (0x11) + 18, // input 83 (0x53 char 'S') => 18 (0x12) + 19, // input 84 (0x54 char 'T') => 19 (0x13) + 20, // input 85 (0x55 char 'U') => 20 (0x14) + 21, // input 86 (0x56 char 'V') => 21 (0x15) + 22, // input 87 (0x57 char 'W') => 22 (0x16) + 23, // input 88 (0x58 char 'X') => 23 (0x17) + 24, // input 89 (0x59 char 'Y') => 24 (0x18) + 25, // input 90 (0x5A char 'Z') => 25 (0x19) + INVALID_VALUE, // input 91 (0x5B) + INVALID_VALUE, // input 92 (0x5C) + INVALID_VALUE, // input 93 (0x5D) + INVALID_VALUE, // input 94 (0x5E) + 63, // input 95 (0x5F char '_') => 63 (0x3F) + INVALID_VALUE, // input 96 (0x60) + 26, // input 97 (0x61 char 'a') => 26 (0x1A) + 27, // input 98 (0x62 char 'b') => 27 (0x1B) + 28, // input 99 (0x63 char 'c') => 28 (0x1C) + 29, // input 100 (0x64 char 'd') => 29 (0x1D) + 30, // input 101 (0x65 char 'e') => 30 (0x1E) + 31, // input 102 (0x66 char 'f') => 31 (0x1F) + 32, // input 103 (0x67 char 'g') => 32 (0x20) + 33, // input 104 (0x68 char 'h') => 33 (0x21) + 34, // input 105 (0x69 char 'i') => 34 (0x22) + 35, // input 106 (0x6A char 'j') => 35 (0x23) + 36, // input 107 (0x6B char 'k') => 36 (0x24) + 37, // input 108 (0x6C char 'l') => 37 (0x25) + 38, // input 109 (0x6D char 'm') => 38 (0x26) + 39, // input 110 (0x6E char 'n') => 39 (0x27) + 40, // input 111 (0x6F char 'o') => 40 (0x28) + 41, // input 112 (0x70 char 'p') => 41 (0x29) + 42, // input 113 (0x71 char 'q') => 42 (0x2A) + 43, // input 114 (0x72 char 'r') => 43 (0x2B) + 44, // input 115 (0x73 char 's') => 44 (0x2C) + 45, // input 116 (0x74 char 't') => 45 (0x2D) + 46, // input 117 (0x75 char 'u') => 46 (0x2E) + 47, // input 118 (0x76 char 'v') => 47 (0x2F) + 48, // input 119 (0x77 char 'w') => 48 (0x30) + 49, // input 120 (0x78 char 'x') => 49 (0x31) + 50, // input 121 (0x79 char 'y') => 50 (0x32) + 51, // input 122 (0x7A char 'z') => 51 (0x33) + INVALID_VALUE, // input 123 (0x7B) + INVALID_VALUE, // input 124 (0x7C) + INVALID_VALUE, // input 125 (0x7D) + INVALID_VALUE, // input 126 (0x7E) + INVALID_VALUE, // input 127 (0x7F) +]); + +fn decode64(b: u8) -> (u8, bool, bool) { + let idx = (b % 64) as usize; + + /* This is basically what + * ``` + * let mask = match b & 0xc0 { + * 0x00 => 0, + * 0x40 => 0xff, + * _ => unsafe { std::mem::MaybeUninit::uninit().assume_init() } + * } + * ``` + * compiles into. + */ + let mask = (-(((b & 0xc0) == 0x40) as bool as i8)) as u8; + + let looked_up = ((!mask) & LUT1.0[idx]) | (mask & LUT2.0[idx]); + ( + looked_up, + ((b | looked_up) as i8).is_negative(), + (looked_up & SPACE_VALUE) == SPACE_VALUE, + ) +} + +#[derive(Copy, Clone)] +pub(super) struct LutAlign64; + +impl super::Decoder for LutAlign64 { + type Block = [u8; 1]; + + #[inline] + fn decode_block(self, block: &mut Self::Block) -> super::BlockResult { + let (a, invalid, space) = decode64(block[0]); + + block[0] = a; + + super::BlockResult { + out_length: if space | invalid { 0 } else { 1 }, + first_invalid: if invalid { Some(0) } else { None }, + } + } + + #[inline(always)] + fn zero_block() -> Self::Block { + [b' '; 1] + } +} diff --git a/src/decode/mod.rs b/src/decode/mod.rs new file mode 100644 index 0000000..9f4b152 --- /dev/null +++ b/src/decode/mod.rs @@ -0,0 +1,390 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +mod avx2; +mod lut_align64; + +use alloc::vec::Vec; +use core::cmp; + +#[must_use] +struct BlockResult { + out_length: u8, + first_invalid: Option, +} + +/// Errors that can occur when decoding a base64 encoded string +#[derive(Debug, Clone, Copy)] +pub enum Error { + /// The input had an invalid length. + InvalidLength, + /// A trailer was found, but it wasn't the right length. + InvalidTrailer, + /// The input contained a character (at the given index) not part of the + /// base64 format. + InvalidCharacter(usize), +} + +trait Decoder: Copy { + type Block: AsRef<[u8]> + AsMut<[u8]>; + + fn decode_block(self, block: &mut Self::Block) -> BlockResult; + fn zero_block() -> Self::Block; +} + +trait Packer: Copy { + type Input: AsRef<[u8]> + AsMut<[u8]> + Default; + const OUT_BUF_LEN: usize; + + /// The caller should pass `output` as a slice with length `OUT_BUF_LEN`. + fn pack_block(self, input: &Self::Input, output: &mut [u8]); +} + +#[derive(Copy, Clone)] +struct Simple; + +impl Packer for Simple { + type Input = [u8; 4]; + const OUT_BUF_LEN: usize = 3; + + #[inline] + fn pack_block(self, input: &Self::Input, output: &mut [u8]) { + output[0] = (input[0] << 2) | (input[1] >> 4); + output[1] = (input[1] << 4) | (input[2] >> 2); + output[2] = (input[2] << 6) | (input[3] >> 0); + } +} + +struct PackState { + packer: P, + cache: P::Input, + pos: usize, +} + +impl PackState

{ + fn extend(&mut self, mut input: &[u8], out: &mut Vec) { + while !input.is_empty() { + let (_, cache_end) = self.cache.as_mut().split_at_mut(self.pos); + let (input_start, input_rest) = input.split_at(cmp::min(input.len(), cache_end.len())); + input = input_rest; + cache_end[..input_start.len()].copy_from_slice(input_start); + if input_start.len() != cache_end.len() { + self.pos += input_start.len(); + } else { + let out_start = out.len(); + out.resize(out_start + P::OUT_BUF_LEN, 0); + self.packer.pack_block(&self.cache, &mut out[out_start..]); + out.truncate(out_start + (core::mem::size_of::() / 4 * 3)); + self.pos = 0; + } + } + } + + fn flush(&mut self, out: &mut Vec, trailer_length: Option) -> Result<(), Error> { + if self.pos % 4 == 1 { + return Err(Error::InvalidLength); + } + + if let Some(trailer_length) = trailer_length { + if (self.pos + trailer_length) % 4 != 0 { + return Err(Error::InvalidTrailer); + } + } + + self.cache.as_mut()[self.pos] = 0; + let out_start = out.len(); + out.resize(out.len() + P::OUT_BUF_LEN, 0); + self.packer.pack_block(&self.cache, &mut out[out_start..]); + out.truncate(out_start + (self.pos * 3 / 4)); + Ok(()) + } +} + +fn decode64(input: &[u8], decoder: D, packer: P) -> Result, Error> { + if input.is_empty() { + return Ok(Vec::new()); + } + + let p_in_len = core::mem::size_of::(); + let p_out_len = p_in_len / 4 * 3; + let cap = + crate::misc::div_roundup(input.len(), p_in_len) * p_out_len - p_out_len + P::OUT_BUF_LEN; + let mut out = Vec::with_capacity(cap); + + let mut packer = PackState::

{ + packer, + cache: P::Input::default(), + pos: 0, + }; + + let mut trailer_length = None; + for (chunk, chunk_start) in input + .chunks(core::mem::size_of::()) + .zip((0..).step_by(core::mem::size_of::())) + { + let mut block = D::zero_block(); + block.as_mut()[..chunk.len()].copy_from_slice(chunk); + let result = decoder.decode_block(&mut block); + + if let Some(idx) = result.first_invalid { + let idx = idx as usize; + if input[chunk_start + idx] == b'=' { + let rest_start = chunk_start + idx + 1; + let rest = &input[rest_start..]; + let mut iter = rest + .iter() + .enumerate() + .filter(|(_, c)| !c.is_ascii_whitespace()); + trailer_length = match (iter.next(), iter.next()) { + (None, _) => Some(1), + (Some((_, b'=')), None) => Some(2), + (Some((_, b'=')), Some((i, _))) | (Some((i, _)), _) => { + return Err(Error::InvalidCharacter(rest_start + i)) + } + }; + } else { + return Err(Error::InvalidCharacter(chunk_start + idx)); + } + } + + packer.extend(&block.as_ref()[..(result.out_length as _)], &mut out); + + if trailer_length.is_some() { + break; + } + } + + packer.flush(&mut out, trailer_length)?; + + Ok(out) +} + +pub(super) fn decode64_arch(input: &[u8]) -> Result, Error> { + unsafe { + if is_x86_feature_detected!("avx2") + && is_x86_feature_detected!("bmi1") + && is_x86_feature_detected!("sse4.2") + && is_x86_feature_detected!("popcnt") + { + let avx2 = avx2::Avx2::new(); + return decode64(input, avx2, avx2); + } + } + decode64(input, lut_align64::LutAlign64, Simple) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::test_support::rand_base64_size; + + pub(super) fn test_avx2() -> avx2::Avx2 { + unsafe { avx2::Avx2::new() } + } + + generate_tests![ + decoders: { + avx2, test_avx2(); + lut_align64, lut_align64::LutAlign64; + }, + packers

: { + avx2, test_avx2(); + simple, Simple; + }, + tests: { + decode, + decode_equivalency, + decode_error, + cmp_rand_1kb, + whitespace_skipped, + all_bytes, + }, + ]; + + fn decode(decoder: D, packer: P) { + static DECODE_TESTS: &[(&[u8], &[u8])] = &[ + // basic tests (from rustc-serialize) + (b"", b""), + (b"Zg==", b"f"), + (b"Zm8=", b"fo"), + (b"Zm9v", b"foo"), + (b"Zm9vYg==", b"foob"), + (b"Zm9vYmE=", b"fooba"), + (b"Zm9vYmFy", b"foobar"), + // with newlines (from rustc-serialize) + (b"Zm9v\r\nYmFy", b"foobar"), + (b"Zm9vYg==\r\n", b"foob"), + (b"Zm9v\nYmFy", b"foobar"), + (b"Zm9vYg==\n", b"foob"), + // white space in trailer + (b"Zm9vYg = = ", b"foob"), + ]; + + for (input, expected) in DECODE_TESTS { + let output = decode64(input, decoder, packer).unwrap(); + if &output != expected { + panic!( + "Test failed. Expected specific output. \n\nInput: {}\nOutput: {:02x?}\nExpected output:{:02x?}\n\n", + std::str::from_utf8(input).unwrap(), + output, + expected + ); + } + } + } + + fn decode_equivalency(decoder: D, packer: P) { + static DECODE_EQUIVALENCY_TESTS: &[(&[u8], &[u8])] = &[ + // url safe test (from rustc-serialize) + (b"-_8", b"+/8="), + ]; + + for (input1, input2) in DECODE_EQUIVALENCY_TESTS { + let output1 = decode64(input1, decoder, packer).unwrap(); + let output2 = decode64(input2, decoder, packer).unwrap(); + if output1 != output2 { + panic!( + "Test failed. Expected same output.\n\nInput 1: {}\nInput 2: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n", + std::str::from_utf8(input1).unwrap(), + std::str::from_utf8(input2).unwrap(), + output1, + output2 + ); + } + } + } + + fn decode_error(decoder: D, packer: P) { + #[rustfmt::skip] + static DECODE_ERROR_TESTS: &[&[u8]] = &[ + // invalid chars (from rustc-serialize) + b"Zm$=", + b"Zg==$", + // invalid padding (from rustc-serialize) + b"Z===", + ]; + + for input in DECODE_ERROR_TESTS { + if decode64(input, decoder, packer).is_ok() { + panic!( + "Test failed. Expected error.\n\nInput: {}\n\n", + std::str::from_utf8(input).unwrap(), + ); + } + } + } + + fn cmp_rand_1kb(decoder: D, packer: P) { + let input = rand_base64_size(1024); + + let output1 = decode64(&input, decoder, packer).unwrap(); + let output2 = decode64(&input, lut_align64::LutAlign64, Simple).unwrap(); + if output1 != output2 { + panic!( + "Test failed. Expected same output.\n\nInput: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n", + std::str::from_utf8(&input).unwrap(), + output1, + output2 + ); + } + } + + fn whitespace_skipped(decoder: D, packer: P) { + let input1 = rand_base64_size(32); + use core::iter::once; + let input2 = input1 + .iter() + .flat_map(|&c| once(c).chain(once(b' '))) + .collect::>(); + + let output1 = decode64(&input1, decoder, packer).unwrap(); + let output2 = decode64(&input2, decoder, packer).unwrap(); + if output1 != output2 { + panic!( + "Test failed. Expected same output.\n\nInput 1: {}\nInput 2: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n", + std::str::from_utf8(&input1).unwrap(), + std::str::from_utf8(&input2).unwrap(), + output1, + output2 + ); + } + } + + fn all_bytes(decoder: D, packer: P) { + let mut set = std::vec![Err(()); 256]; + for (i, &b) in crate::misc::LUT_STANDARD.iter().enumerate() { + set[b as usize] = Ok(Some(i as u8)); + } + // add URL-safe set + set[b'-' as usize] = Ok(Some(62)); + set[b'_' as usize] = Ok(Some(63)); + // add whitespace + set[b' ' as usize] = Ok(None); + set[b'\n' as usize] = Ok(None); + set[b'\t' as usize] = Ok(None); + set[b'\r' as usize] = Ok(None); + set[0x0c] = Ok(None); + + for (i, &expected) in set.iter().enumerate() { + let output = match decode64(&[i as u8, i as u8], decoder, packer) + .as_ref() + .map(|v| &v[..]) + { + Ok(&[]) => Ok(None), + Ok(&[v]) => Ok(Some(v >> 2)), + Ok(_) => panic!("Result is more than 1 byte long"), + Err(_) => Err(()), + }; + assert_eq!(output, expected); + } + } +} + +#[cfg(all(test, feature = "nightly"))] +mod benches { + use super::{tests::test_avx2, *}; + + use test::Bencher; + + use crate::test_support::rand_base64_size; + + #[bench] + fn avx2_1mb(b: &mut Bencher) { + let input = rand_base64_size(1024 * 1024); + b.iter(|| { + let ret = decode64(&input, test_avx2(), test_avx2()).unwrap(); + std::hint::black_box(ret); + }); + } + + #[bench] + fn lut_align64_1mb(b: &mut Bencher) { + let input = rand_base64_size(1024 * 1024); + b.iter(|| { + let ret = decode64(&input, lut_align64::LutAlign64, Simple).unwrap(); + std::hint::black_box(ret); + }); + } + + #[bench] + fn avx2_1kb(b: &mut Bencher) { + let input = rand_base64_size(1024); + b.iter(|| { + let ret = decode64(&input, test_avx2(), test_avx2()).unwrap(); + std::hint::black_box(ret); + }); + } + + #[bench] + fn lut_align64_1kb(b: &mut Bencher) { + let input = rand_base64_size(1024); + b.iter(|| { + let ret = decode64(&input, lut_align64::LutAlign64, Simple).unwrap(); + std::hint::black_box(ret); + }); + } +} diff --git a/src/encode/avx2.rs b/src/encode/avx2.rs new file mode 100644 index 0000000..4c60047 --- /dev/null +++ b/src/encode/avx2.rs @@ -0,0 +1,156 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use core::arch::x86_64::*; + +use crate::avx2::*; + +/// # Safety +/// The caller should ensure the requisite CPU features are enabled. +#[target_feature(enable = "avx2")] +unsafe fn encode_block(block: &mut ::Block, charset: crate::CharacterSet) { + let input = array_as_m256i(*block); + + // The general idea is to recognize that the 6-bit input can fall in one of + // five output ranges: uppercase letters, lowercase letters, numbers, + // special character 1, special character 2. First we calculate what range + // the input is in, then we determine what value would need to be added to + // arrive at the right ASCII output value, and add it. + + // Check whether the input should be encoded as a letter. + // + // If it should, result is now 0. Otherwise, it's in the range 1...12, + // inclusive. + let result = _mm256_subs_epu8(input, _mm256_set1_epi8(51)); + + // Check whether the input should be encoded as an uppercase letter. + // + // If it should, result is now 0xff. Otherwise, it's 0. + let less = _mm256_cmpgt_epi8(_mm256_set1_epi8(26), input); + + // Choose one of the 5 ranges for each input. + // + // 0: lowercase letter + // 1...10: number + // 11: special character 1 + // 12: special character 2 + // 13: uppercase letter + let result = _mm256_or_si256(result, _mm256_and_si256(less, _mm256_set1_epi8(13))); + + // Choose the lookup table based on the character set. + // + // The lookup table gives the amount that needs to be added to the input + // (AKA the shift) to convert it to the appropriate ASCII code. + let shift_lut = match charset { + crate::CharacterSet::Standard => dup_mm_setr_epi8([ + b'a' as i8 - 26, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'+' as i8 - 62, + b'/' as i8 - 63, + b'A' as _, + 0, + 0, + ]), + crate::CharacterSet::UrlSafe => dup_mm_setr_epi8([ + b'a' as i8 - 26, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'0' as i8 - 52, + b'-' as i8 - 62, + b'_' as i8 - 63, + b'A' as _, + 0, + 0, + ]), + }; + + let shift = _mm256_shuffle_epi8(shift_lut, result); + + *block = m256i_as_array(_mm256_add_epi8(shift, input)); +} + +/// # Safety +/// The caller should ensure the requisite CPU features are enabled. +#[target_feature(enable = "avx2")] +unsafe fn unpack_block( + input: &::Input, + output: &mut ::Output, +) { + let input = _mm256_set_m128i( + core::ptr::read_unaligned(input.as_ptr().offset(8) as _), + core::ptr::read_unaligned(input.as_ptr() as _), + ); + + #[rustfmt::skip] + let shuf = _mm256_set_epi8( + 14, 15, 13, 14, + 11, 12, 10, 11, + 8, 9, 7, 8, + 5, 6, 4, 5, + + 10, 11, 9, 10, + 7, 8, 6, 7, + 4, 5, 3, 4, + 1, 2, 0, 1, + ); + + let input = _mm256_shuffle_epi8(input, shuf); + + let t0 = _mm256_and_si256(input, _mm256_set1_epi32(0x0fc0fc00)); + let t1 = _mm256_mulhi_epu16(t0, _mm256_set1_epi32(0x04000040)); + let t2 = _mm256_and_si256(input, _mm256_set1_epi32(0x003f03f0)); + let t3 = _mm256_mullo_epi16(t2, _mm256_set1_epi32(0x01000010)); + *output = m256i_as_array(_mm256_or_si256(t1, t3)); +} + +#[derive(Copy, Clone)] +pub(super) struct Avx2 { + _private: (), +} + +impl Avx2 { + /// # Safety + /// The caller should ensure the requisite CPU features are enabled. + #[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")] + pub(super) unsafe fn new() -> Avx2 { + Avx2 { _private: () } + } +} + +impl super::Encoder for Avx2 { + type Block = [u8; 32]; + + fn encode_block(self, block: &mut Self::Block, charset: crate::CharacterSet) { + // safe: `self` was given as a witness that the features are available + unsafe { encode_block(block, charset) } + } +} + +impl super::Unpacker for Avx2 { + type Input = [u8; 24]; + type Output = [u8; 32]; + + fn unpack_block(self, input: &Self::Input, output: &mut Self::Output) { + // safe: `self` was given as a witness that the features are available + unsafe { unpack_block(input, output) } + } +} diff --git a/src/encode/lut_align64.rs b/src/encode/lut_align64.rs new file mode 100644 index 0000000..85414e0 --- /dev/null +++ b/src/encode/lut_align64.rs @@ -0,0 +1,26 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use crate::lut_align64::CacheLineLut; + +static LUT_STANDARD: CacheLineLut = CacheLineLut(crate::misc::LUT_STANDARD); +static LUT_URLSAFE: CacheLineLut = + CacheLineLut(*b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"); + +#[derive(Copy, Clone)] +pub(super) struct LutAlign64; + +impl super::Encoder for LutAlign64 { + type Block = [u8; 1]; + + fn encode_block(self, block: &mut Self::Block, charset: crate::CharacterSet) { + let lut = match charset { + crate::Standard => &LUT_STANDARD, + crate::UrlSafe => &LUT_URLSAFE, + }; + block[0] = lut.0[block[0] as usize]; + } +} diff --git a/src/encode/mod.rs b/src/encode/mod.rs new file mode 100644 index 0000000..0e155a2 --- /dev/null +++ b/src/encode/mod.rs @@ -0,0 +1,316 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +mod avx2; +mod lut_align64; + +use alloc::{string::String, vec::Vec}; + +trait Encoder: Copy { + type Block: AsRef<[u8]> + AsMut<[u8]> + Default; + + fn encode_block(self, block: &mut Self::Block, charset: crate::CharacterSet); +} + +trait Unpacker: Copy { + type Input: AsRef<[u8]> + AsMut<[u8]> + Default; + type Output: AsRef<[u8]> + AsMut<[u8]> + Default; + + fn unpack_block(self, input: &Self::Input, output: &mut Self::Output); +} + +#[derive(Copy, Clone)] +struct Simple; + +impl Unpacker for Simple { + type Input = [u8; 3]; + type Output = [u8; 4]; + + fn unpack_block(self, input: &Self::Input, output: &mut Self::Output) { + output[0] = input[0] >> 2; + output[1] = ((input[0] & 0x03) << 4) | (input[1] >> 4); + output[2] = ((input[1] & 0x0f) << 2) | (input[2] >> 6); + output[3] = (input[2] & 0x3f) << 0; + } +} + +trait Lcm { + type Array: AsRef<[u8]> + AsMut<[u8]> + Default; +} + +trait SplitArray { + type Output: AsRef<[T]> + AsMut<[T]>; + fn split_mut_internal(&mut self) -> &mut Self::Output; +} + +trait SplitArrayExt { + fn split_mut(&mut self) -> &mut [T] + where + Self: SplitArray, + { + self.split_mut_internal().as_mut() + } +} + +impl SplitArrayExt for T {} + +macro_rules! impl_lcm_array { + ($($am:ident / )* $a:literal, $($bm:ident / )* $b:literal, $lcm:literal) => { + impl Lcm for ([u8; $a], [u8; $b]) { + type Array = [u8; $lcm]; + } + + impl_lcm_array!(@split $($am / )* $a, $lcm); + impl_lcm_array!(@split $($bm / )* $b, $lcm); + }; + (@split $($nm:ident / )* $n:literal, $lcm:literal) => { + $(#[cfg(all(not($nm), $nm))])* + impl SplitArray<[T; $n]> for [T; $lcm] { + type Output = [[T; $n]; $lcm / $n]; + + fn split_mut_internal(&mut self) -> &mut Self::Output { + unsafe { &mut *(self as *mut _ as *mut _) } + } + } + }; +} + +impl_lcm_array!(32, skip / 32, 32); +impl_lcm_array!(skip / 32, 1, 32); +impl_lcm_array!(4, skip / 32, 32); +impl_lcm_array!(4, 1, 4); + +trait TakePrefix: Sized { + fn take_prefix(&mut self, mid: usize) -> Self; +} + +impl<'a, T: 'a> TakePrefix for &'a [T] { + fn take_prefix(&mut self, mid: usize) -> Self { + let prefix = &self[..mid]; + *self = &self[mid..]; + prefix + } +} + +impl crate::Newline { + fn append_to(self, buf: &mut Vec) { + if let crate::Newline::CRLF = self { + buf.push(b'\r'); + } + buf.push(b'\n'); + } +} + +fn encode64( + input: &[u8], + config: crate::Config, + encoder: E, + unpacker: U, +) -> String +where + (U::Output, E::Block): Lcm, + <(U::Output, E::Block) as Lcm>::Array: SplitArray + SplitArray, +{ + let mut len = crate::misc::div_roundup(input.len(), 3) * 4; + let mut next_nl = config.line_length; + if let Some(line_length) = config.line_length { + let nl_len = match config.newline { + crate::Newline::LF => 1, + crate::Newline::CRLF => 2, + }; + len = crate::misc::div_roundup(len, line_length) * (line_length + nl_len); + } + let mut output = Vec::with_capacity(len); + + let mut buffer = <(U::Output, E::Block) as Lcm>::Array::default(); + + let mut input_iter = input.chunks(core::mem::size_of::()); + while input_iter.len() > 0 { + let mut input_len = 0; + for chunk in buffer.split_mut::() { + let mut input_block = U::Input::default(); + if let Some(input_next) = input_iter.next() { + input_len += input_next.len(); + input_block.as_mut()[..input_next.len()].copy_from_slice(input_next); + } + unpacker.unpack_block(&input_block, chunk); + } + for chunk in buffer.split_mut::() { + encoder.encode_block(chunk, config.char_set); + } + + let mut buffer = &buffer.as_ref()[..crate::misc::div_roundup(input_len * 4, 3)]; + + if let Some(mut nl_index) = next_nl { + while (output.len() + buffer.len()) > nl_index { + let line = buffer.take_prefix(nl_index - output.len()); + output.extend_from_slice(&line); + config.newline.append_to(&mut output); + nl_index = output.len() + config.line_length.unwrap(); + } + next_nl = Some(nl_index); + } + + output.extend_from_slice(buffer); + } + + if config.pad { + if let Some(mut nl_index) = next_nl { + let trailer_length = match input.len() % 3 { + 1 => 2, + 2 => 1, + _ => 0, + }; + for _ in 0..trailer_length { + if output.len() == nl_index { + config.newline.append_to(&mut output); + nl_index = output.len() + config.line_length.unwrap(); + } + output.push(b'='); + } + } else if output.len() != len { + output.resize(len, b'='); + } + } + + String::from_utf8(output).unwrap() +} + +pub(super) fn encode64_arch(input: &[u8], config: crate::Config) -> String { + unsafe { + if is_x86_feature_detected!("avx2") { + let avx2 = avx2::Avx2::new(); + return encode64(input, config, avx2, avx2); + } + } + encode64(input, config, lut_align64::LutAlign64, Simple) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{Config, Newline, STANDARD, URL_SAFE}; + + pub(super) fn test_avx2() -> avx2::Avx2 { + unsafe { avx2::Avx2::new() } + } + + generate_tests![ + encoders: { + lut_align64, lut_align64::LutAlign64; + avx2, test_avx2(); + }, + unpackers: { + simple, Simple; + avx2, test_avx2(); + }, + tests: { + encode, + }, + ]; + + fn encode(encoder: E, unpacker: U) + where + (U::Output, E::Block): Lcm, + <(U::Output, E::Block) as Lcm>::Array: SplitArray + SplitArray, + { + static ENCODE_TESTS: &[(&[u8], Config, &str)] = &[ + // basic tests (from rustc-serialize) + (b"", STANDARD, ""), + (b"f", STANDARD, "Zg=="), + (b"fo", STANDARD, "Zm8="), + (b"foo", STANDARD, "Zm9v"), + (b"foob", STANDARD, "Zm9vYg=="), + (b"fooba", STANDARD, "Zm9vYmE="), + (b"foobar", STANDARD, "Zm9vYmFy"), + // with crlf break (from rustc-serialize) + (b"foobar", Config {line_length: Some(4), ..STANDARD}, "Zm9v\r\nYmFy"), + // with lf break (from rustc-serialize) + (b"foobar", Config {line_length: Some(4), newline: Newline::LF, ..STANDARD}, "Zm9v\nYmFy"), + // without padding (from rustc-serialize) + (b"f", Config {pad: false, ..STANDARD}, "Zg"), + (b"fo", Config {pad: false, ..STANDARD}, "Zm8"), + // URL safe (from rustc-serialize) + (&[251, 255], URL_SAFE, "-_8"), + (&[251, 255], STANDARD, "+/8="), + + // new tests + (b"f", Config {line_length: Some(1), ..STANDARD}, "Z\r\ng\r\n=\r\n="), + (b"fo", Config {line_length: Some(1), ..STANDARD}, "Z\r\nm\r\n8\r\n="), + (b"foob", Config {line_length: Some(4), ..STANDARD}, "Zm9v\r\nYg=="), + (b"foob", Config {line_length: Some(5), ..STANDARD}, "Zm9vY\r\ng=="), + (b"foob", Config {line_length: Some(6), ..STANDARD}, "Zm9vYg\r\n=="), + (b"foob", Config {line_length: Some(7), ..STANDARD}, "Zm9vYg=\r\n="), + (b"foob", Config {line_length: Some(8), ..STANDARD}, "Zm9vYg=="), + (b"foobfoo", Config {line_length: Some(3), ..STANDARD}, "Zm9\r\nvYm\r\nZvb\r\nw=="), + (b"foobfoo", Config {line_length: Some(4), ..STANDARD}, "Zm9v\r\nYmZv\r\nbw=="), + (b"foobfoo", Config {line_length: Some(5), ..STANDARD}, "Zm9vY\r\nmZvbw\r\n=="), + (b"\x00\x10\x83\x10\x51\x87\x20\x92\x8b\x30\xd3\x8f\x41\x14\x93\x51\x55\x97\x61\x96\x9b\x71\xd7\x9f", STANDARD, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdef"), + ]; + + for (input, config, expected) in ENCODE_TESTS { + let output = encode64(input, *config, encoder, unpacker); + if &output != expected { + panic!( + "Test failed. Expected specific output. \n\nInput: {:02x?}\nOutput: {}\nExpected output:{}\n\n", + input, + output, + expected + ); + } + } + } +} + +#[cfg(all(test, feature = "nightly"))] +mod benches { + use super::{tests::test_avx2, *}; + + use test::Bencher; + + use rand::{thread_rng, RngCore}; + + #[bench] + fn avx2_1mb(b: &mut Bencher) { + let mut input = std::vec![0; 1024*1024]; + thread_rng().fill_bytes(&mut input); + b.iter(|| { + let ret = encode64(&input, crate::STANDARD, test_avx2(), test_avx2()); + std::hint::black_box(ret); + }); + } + + #[bench] + fn lut_align64_1mb(b: &mut Bencher) { + let mut input = std::vec![0; 1024*1024]; + thread_rng().fill_bytes(&mut input); + b.iter(|| { + let ret = encode64(&input, crate::STANDARD, lut_align64::LutAlign64, Simple); + std::hint::black_box(ret); + }); + } + + #[bench] + fn avx2_1kb(b: &mut Bencher) { + let mut input = std::vec![0; 1024]; + thread_rng().fill_bytes(&mut input); + b.iter(|| { + let ret = encode64(&input, crate::STANDARD, test_avx2(), test_avx2()); + std::hint::black_box(ret); + }); + } + + #[bench] + fn lut_align64_1kb(b: &mut Bencher) { + let mut input = std::vec![0; 1024]; + thread_rng().fill_bytes(&mut input); + b.iter(|| { + let ret = encode64(&input, crate::STANDARD, lut_align64::LutAlign64, Simple); + std::hint::black_box(ret); + }); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..65eb09c --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,213 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +//! This crate provides an implementation of Base64 encoding/decoding that is +//! designed to be resistant against software side-channel attacks (such as +//! timing & cache attacks), see below for details. On certain platforms it +//! also uses SIMD making it very fast. This makes it suitable for e.g. +//! decoding cryptographic private keys in PEM format. +//! +//! The API is very similar to the base64 implementation in the old +//! rustc-serialize crate, making it easy to use in existing projects. +//! +//! # Resistance against Software Side-Channel Attacks +//! +//! An indistinguishable-time (colloquially: constant-time) implementation of +//! an algorithm has a runtime that's independent of the data being processed. +//! This indistinguishability is usually based on the control flow of the +//! program as well as its memory access pattern. In that case, +//! indistinguishability may be achieved by making sure the control flow and +//! memory access pattern don't depend on the data. Other factors, such as +//! instruction cycle count may also be consequential. +//! +//! See the [BearSSL page on constant-time cryptography] for more information. +//! +//! The runtime of the implementations in this crate is intended to be +//! dependent only on whitespace and the length of the valid data, not the data +//! itself. +//! +//! [BearSSL page on constant-time cryptography]: https://bearssl.org/constanttime.html +//! +//! # Implementation +//! +//! Depending on the runtime CPU architecture, this crate uses different +//! implementations with different security properties. +//! +//! * x86 with AVX2: All lookup tables are implemented with SIMD +//! instructions. No secret-dependent memory accceses. +//! * Other platforms: Lookups are limited to 64-byte aligned lookup tables. On +//! platforms with 64-byte cache lines this may be sufficient to prevent +//! certain cache side-channel attacks. However, it's known that this is [not +//! sufficient for all platforms]. +//! +//! [not sufficient on some platforms]: https://ts.data61.csiro.au/projects/TS/cachebleed/ + +#![no_std] +#![cfg_attr(all(test, feature = "nightly"), feature(test))] + +extern crate alloc; +#[cfg(any(test, feature = "std"))] +#[macro_use] +extern crate std; +#[cfg(all(test, feature = "nightly"))] +extern crate test; + +#[cfg(test)] +#[macro_use] +mod test_support; + +#[macro_use] +mod misc; + +mod avx2; +mod lut_align64; + +mod decode; +mod encode; + +use alloc::{string::String, vec::Vec}; + +pub use self::CharacterSet::*; + +/// Available encoding character sets +#[derive(Clone, Copy, Debug)] +pub enum CharacterSet { + /// The standard character set (uses `+` and `/`) + Standard, + /// The URL safe character set (uses `-` and `_`) + UrlSafe, +} + +/// Available newline types +#[derive(Clone, Copy, Debug)] +pub enum Newline { + /// A linefeed (i.e. Unix-style newline) + LF, + /// A carriage return and a linefeed (i.e. Windows-style newline) + CRLF, +} + +/// Contains configuration parameters for `to_base64`. +#[derive(Clone, Copy, Debug)] +pub struct Config { + /// Character set to use + pub char_set: CharacterSet, + /// Newline to use + pub newline: Newline, + /// True to pad output with `=` characters + pub pad: bool, + /// `Some(len)` to wrap lines at `len`, `None` to disable line wrapping + pub line_length: Option, +} + +/// Configuration for RFC 4648 standard base64 encoding +pub static STANDARD: Config = Config { + char_set: Standard, + newline: Newline::CRLF, + pad: true, + line_length: None, +}; + +/// Configuration for RFC 4648 base64url encoding +pub static URL_SAFE: Config = Config { + char_set: UrlSafe, + newline: Newline::CRLF, + pad: false, + line_length: None, +}; + +/// Configuration for RFC 2045 MIME base64 encoding +pub static MIME: Config = Config { + char_set: Standard, + newline: Newline::CRLF, + pad: true, + line_length: Some(76), +}; + +/// A trait for converting a value to base64 encoding. +pub trait ToBase64 { + /// Converts the value of `self` to a base64 value following the specified + /// format configuration, returning the owned string. + fn to_base64(&self, config: Config) -> String; +} + +impl ToBase64 for [u8] { + /// Turn a vector of `u8` bytes into a base64 string. + /// + /// # Example + /// + /// ```rust + /// use b64_ct::{ToBase64, STANDARD}; + /// + /// fn main () { + /// let str = [52,32].to_base64(STANDARD); + /// println!("base 64 output: {:?}", str); + /// } + /// ``` + fn to_base64(&self, config: Config) -> String { + encode::encode64_arch(self, config) + } +} + +impl<'a, T: ?Sized + ToBase64> ToBase64 for &'a T { + fn to_base64(&self, config: Config) -> String { + (**self).to_base64(config) + } +} + +#[doc(inline)] +pub use decode::Error as FromBase64Error; + +/// A trait for converting from base64 encoded values. +pub trait FromBase64 { + /// Converts the value of `self`, interpreted as base64 encoded data, into + /// an owned vector of bytes, returning the vector. + fn from_base64(&self) -> Result, FromBase64Error>; +} + +impl FromBase64 for str { + /// Convert any base64 encoded string (literal, `@`, `&`, or `~`) + /// to the byte values it encodes. + /// + /// You can use the `String::from_utf8` function to turn a `Vec` into a + /// string with characters corresponding to those values. + /// + /// # Example + /// + /// This converts a string literal to base64 and back. + /// + /// ```rust + /// use b64_ct::{ToBase64, FromBase64, STANDARD}; + /// + /// fn main () { + /// let hello_str = b"Hello, World".to_base64(STANDARD); + /// println!("base64 output: {}", hello_str); + /// let res = hello_str.from_base64(); + /// if res.is_ok() { + /// let opt_bytes = String::from_utf8(res.unwrap()); + /// if opt_bytes.is_ok() { + /// println!("decoded from base64: {:?}", opt_bytes.unwrap()); + /// } + /// } + /// } + /// ``` + #[inline] + fn from_base64(&self) -> Result, FromBase64Error> { + self.as_bytes().from_base64() + } +} + +impl FromBase64 for [u8] { + fn from_base64(&self) -> Result, FromBase64Error> { + decode::decode64_arch(self) + } +} + +impl<'a, T: ?Sized + FromBase64> FromBase64 for &'a T { + fn from_base64(&self) -> Result, FromBase64Error> { + (**self).from_base64() + } +} diff --git a/src/lut_align64.rs b/src/lut_align64.rs new file mode 100644 index 0000000..34154cf --- /dev/null +++ b/src/lut_align64.rs @@ -0,0 +1,8 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#[repr(align(64))] +pub(crate) struct CacheLineLut(pub [u8; 64]); diff --git a/src/misc.rs b/src/misc.rs new file mode 100644 index 0000000..ce22325 --- /dev/null +++ b/src/misc.rs @@ -0,0 +1,21 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#[cfg(not(feature = "std"))] +#[macro_export] +macro_rules! is_x86_feature_detected { + ($feat:literal) => { + cfg!(target_feature = $feat) + }; +} + +pub(crate) const LUT_STANDARD: [u8; 64] = + *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +#[inline(always)] +pub(crate) fn div_roundup(numerator: usize, denominator: usize) -> usize { + (numerator + denominator - 1) / denominator +} diff --git a/src/test_support.rs b/src/test_support.rs new file mode 100644 index 0000000..8a1b9b8 --- /dev/null +++ b/src/test_support.rs @@ -0,0 +1,56 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use rand::{distributions::Distribution, Rng}; + +struct Base64; + +impl Distribution for Base64 { + fn sample(&self, rng: &mut R) -> u8 { + crate::misc::LUT_STANDARD[(rng.next_u32() & 0x3f) as usize] + } +} + +pub fn rand_base64_size(s: usize) -> std::vec::Vec { + rand::thread_rng().sample_iter(&Base64).take(s).collect() +} + +// `with_cartesian_products!(m (a b) (c d));` expands to `m!(a c); m!(a d); m!(b c); m!(b d);`. +#[macro_export] +macro_rules! with_cartesian_products { + (@$m:ident $acc:tt ($($choice:tt)*) $rest:tt) => { + $(with_cartesian_products!(@$rest $acc $choice $m);)* + }; + (@() ($($acc:tt)*) $choice:tt $m:ident) => { + $m!($($acc)* $choice); + }; + (@($next:tt $($rest:tt)*) ($($acc:tt)*) $choice:tt $m:ident) => { + with_cartesian_products!(@$m ($($acc)* $choice) $next ($($rest)*)); + }; + ($m:ident $next:tt $($rest:tt)*) => { + with_cartesian_products!(@$m () $next ($($rest)*)); + }; +} + +#[macro_export] +macro_rules! generate_tests { + ( + $_a:ident<$a:ident>: { $($an:ident, $at:expr;)* }, + $_b:ident<$b:ident>: { $($bn:ident, $bt:expr;)* }, + tests: { $($tn:ident,)* }, + ) => { + with_cartesian_products!( generate_tests ((@ $a $b)) ($($tn)*) ($(($an, $at))*) ($(($bn, $bt))*) ); + }; + ((@ $a:ident $b:ident) $tn:ident ($an:ident, $at:expr) ($bn:ident, $bt:expr)) => { + paste::item! { + #[test] + #[allow(non_snake_case)] + fn [< $tn _ $a $an _ $b $bn >]() { + $tn($at, $bt); + } + } + } +}