Skip to content

Commit

Permalink
Remove SHA peripheral requirement.
Browse files Browse the repository at this point in the history
- Document safety of the SHA driver.
  • Loading branch information
AnthonyGrondin committed Aug 13, 2024
1 parent 2013d4d commit 83236d5
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 86 deletions.
131 changes: 62 additions & 69 deletions esp-hal/src/sha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
//! # use nb::block;
//! let source_data = "HELLO, ESPRESSIF!".as_bytes();
//! let mut remaining = source_data;
//! let mut hasher = Sha256::new(peripherals.SHA);
//! let mut hasher = Sha256::new();
//! // Short hashes can be created by decreasing the output buffer to the
//! // desired length
//! let mut output = [0u8; 32];
Expand Down Expand Up @@ -66,8 +66,6 @@ use core::{convert::Infallible, marker::PhantomData};
pub use digest::Digest;

use crate::{
peripheral::{Peripheral, PeripheralRef},
peripherals::SHA,
reg_access::{AlignmentHelper, SocDependentEndianess},
system::PeripheralClockControl,
};
Expand Down Expand Up @@ -126,7 +124,7 @@ impl<DM: crate::Mode> Context<DM> {

// This implementation might fail after u32::MAX/8 bytes, to increase please see
// ::finish() length/self.cursor usage
pub trait Sha<'d, DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {
pub trait Sha<DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {
/// Constant containing the name of the algorithm as a string.
const ALGORITHM: &'static str;

Expand All @@ -149,6 +147,7 @@ pub trait Sha<'d, DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {

#[cfg(not(esp32))]
fn is_busy(&self) -> bool {
// Safety: This is safe because we only read `SHA_BUSY_REG`
let sha = unsafe { crate::peripherals::SHA::steal() };
sha.busy().read().bits() != 0
}
Expand All @@ -158,6 +157,8 @@ pub trait Sha<'d, DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {

#[cfg(not(esp32))]
fn process_buffer(&mut self) {
// Safety: This is safe because digest state is restored and saved between
// operations.
let sha = unsafe { crate::peripherals::SHA::steal() };
// Setup SHA Mode before processing current buffer.
sha.mode()
Expand All @@ -178,29 +179,43 @@ pub trait Sha<'d, DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {
// SET SHA_CONTINUE_REG
sha.continue_().write(|w| unsafe { w.bits(1) });
}

// Wait until buffer has completely processed
while self.is_busy() {}

// Save the content of the current hash for interleaving operation.
let mut saved_digest = [0u8; 64];
self.alignment_helper.volatile_read_regset(
sha.h_mem(0).as_ptr(),
&mut saved_digest,
64 / self.alignment_helper.align_size(),
);
self.saved_digest.replace(saved_digest);
}

fn flush_data(&mut self) -> nb::Result<(), Infallible> {
let sha = unsafe { crate::peripherals::SHA::steal() };
if self.is_busy() {
return Err(nb::Error::WouldBlock);
}

// Safety: This is safe because the buffer is processed after being flushed to memory.
let sha = unsafe { crate::peripherals::SHA::steal() };

let chunk_len = self.chunk_length();
let ctx = self.deref_mut();

// Flush aligned buffer in memory before flushing alignment_helper
unsafe {
core::ptr::copy_nonoverlapping(
self.buffer.as_ptr(),
ctx.buffer.as_ptr(),
#[cfg(esp32)]
sha.text(0).as_ptr(),
#[cfg(not(esp32))]
sha.m_mem(0).as_ptr(),
32,
(ctx.cursor % chunk_len) / ctx.alignment_helper.align_size(),
);
}

let ctx = self.deref_mut();
let flushed = ctx.alignment_helper.flush_to(
#[cfg(esp32)]
sha.text(0).as_ptr(),
Expand All @@ -221,7 +236,6 @@ pub trait Sha<'d, DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {
// This function ensures that incoming data is aligned to u32 (due to issues
// with cpy_mem<u8>)
fn write_data<'a>(&mut self, incoming: &'a [u8]) -> nb::Result<&'a [u8], Infallible> {
let sha = unsafe { crate::peripherals::SHA::steal() };
let mod_cursor = self.cursor % self.chunk_length();

let chunk_len = self.chunk_length();
Expand All @@ -239,7 +253,10 @@ pub trait Sha<'d, DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {

// If bound reached we write the buffer to memory and process it.
if bound_reached {
// Safety: This is safe because the bound has been reached and the buffer will
// be fully processed then saved.
unsafe {
let sha = crate::peripherals::SHA::steal();
core::ptr::copy_nonoverlapping(
self.buffer.as_ptr(),
#[cfg(esp32)]
Expand All @@ -250,21 +267,6 @@ pub trait Sha<'d, DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {
);
}
self.process_buffer();

// Save the content of the current hash for interleaving operation.
#[cfg(not(esp32))]
{
// Wait until buffer has completely processed
while self.is_busy() {}

let mut saved_digest = [0u8; 64];
self.alignment_helper.volatile_read_regset(
sha.h_mem(0).as_ptr(),
&mut saved_digest,
64 / self.alignment_helper.align_size(),
);
self.saved_digest.replace(saved_digest);
}
}

Ok(remaining)
Expand Down Expand Up @@ -384,71 +386,57 @@ pub trait Sha<'d, DM: crate::Mode>: core::ops::DerefMut<Target = Context<DM>> {

Ok(())
}

/// Create a new instance in [crate::Blocking] mode.
#[cfg_attr(not(esp32), doc = "Optionally an interrupt handler can be bound.")]
fn new_internal(sha: impl Peripheral<P = SHA> + 'd) -> (PeripheralRef<'d, SHA>, Context<DM>) {
crate::into_ref!(sha);

PeripheralClockControl::reset(crate::system::Peripheral::Sha);
PeripheralClockControl::enable(crate::system::Peripheral::Sha);

(
sha,
Context {
cursor: 0,
first_run: true,
finished: false,
alignment_helper: AlignmentHelper::default(),
buffer: [0u32; 32],
#[cfg(not(esp32))]
saved_digest: None,
phantom: PhantomData,
},
)
}
}

/// This macro implements the Sha<'a, DM> trait for a specified Sha algorithm
/// and a set of parameters
macro_rules! impl_sha {
($name: ident, $mode_bits: tt, $digest_length: tt, $chunk_length: tt) => {
pub struct $name<'d, DM: crate::Mode>(PeripheralRef<'d, SHA>, Context<DM>);
pub struct $name<DM: crate::Mode>(Context<DM>);

impl<'d> $name<'d, crate::Blocking> {
impl $name<crate::Blocking> {
/// Create a new instance in [crate::Blocking] mode.
#[cfg_attr(not(esp32), doc = "Optionally an interrupt handler can be bound.")]
pub fn new(sha: impl Peripheral<P = SHA> + 'd) -> $name<'d, crate::Blocking> {
let (sha, ctx) = Self::new_internal(sha);
Self(sha, ctx)
pub fn new() -> $name<crate::Blocking> {
Self::default()
}
}

/// Automatically implement Deref + DerefMut to get access to inner context
impl<'a, DM: crate::Mode> core::ops::Deref for $name<'a, DM> {
impl<DM: crate::Mode> core::ops::Deref for $name<DM> {
type Target = Context<DM>;

fn deref(&self) -> &Self::Target {
&self.1
&self.0
}
}

impl<'a, DM: crate::Mode> core::ops::DerefMut for $name<'a, DM> {
impl<DM: crate::Mode> core::ops::DerefMut for $name<DM> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.1
&mut self.0
}
}

/// Implement Default to create hasher out of thin air
impl<'d> core::default::Default for $name<'d, crate::Blocking> {
impl core::default::Default for $name<crate::Blocking> {
fn default() -> Self {
let sha = unsafe { crate::peripherals::SHA::steal() };
let (sha, ctx) = Self::new_internal(sha);
Self(sha, ctx)
PeripheralClockControl::reset(crate::system::Peripheral::Sha);
PeripheralClockControl::enable(crate::system::Peripheral::Sha);

Self(Context {
cursor: 0,
first_run: true,
finished: false,
alignment_helper: AlignmentHelper::default(),
buffer: [0u32; 32],
#[cfg(not(esp32))]
saved_digest: None,
phantom: PhantomData,
})
}
}

impl<'d> $crate::sha::Sha<'d, crate::Blocking> for $name<'d, crate::Blocking> {
impl $crate::sha::Sha<crate::Blocking> for $name<crate::Blocking> {
const ALGORITHM: &'static str = stringify!($name);

#[cfg(not(esp32))]
Expand All @@ -467,26 +455,31 @@ macro_rules! impl_sha {
// ESP32 uses different registers for its operation
#[cfg(esp32)]
fn load_reg(&self) {
// Safety: This is safe because digest state is restored and saved between
// operations.
let sha = unsafe { crate::peripherals::SHA::steal() };
paste::paste! {
unsafe { self.0.[< $name:lower _load >]().write(|w| w.bits(1)) };
unsafe { sha.[< $name:lower _load >]().write(|w| w.bits(1)) };
}
}

#[cfg(esp32)]
fn is_busy(&self) -> bool {
let sha = unsafe { crate::peripherals::SHA::steal() };
paste::paste! {
self.0.[< $name:lower _busy >]().read().[< $name:lower _busy >]().bit_is_set()
sha.[< $name:lower _busy >]().read().[< $name:lower _busy >]().bit_is_set()
}
}

#[cfg(esp32)]
fn process_buffer(&mut self) {
let sha = unsafe { crate::peripherals::SHA::steal() };
paste::paste! {
if self.first_run {
self.0.[< $name:lower _start >]().write(|w| unsafe { w.bits(1) });
sha.[< $name:lower _start >]().write(|w| unsafe { w.bits(1) });
self.first_run = false;
} else {
self.0.[< $name:lower _continue >]().write(|w| unsafe { w.bits(1) });
sha.[< $name:lower _continue >]().write(|w| unsafe { w.bits(1) });
}
}
}
Expand All @@ -496,10 +489,10 @@ macro_rules! impl_sha {
/// Note: digest has a blanket trait implementation for [digest::Digest] for any
/// element that implements FixedOutput + Default + Update + HashMarker
#[cfg(feature = "digest")]
impl<'d, DM: crate::Mode> digest::HashMarker for $name<'d, DM> {}
impl<DM: crate::Mode> digest::HashMarker for $name<DM> {}

#[cfg(feature = "digest")]
impl<'a, DM: crate::Mode> digest::OutputSizeUser for $name<'a, DM> {
impl<DM: crate::Mode> digest::OutputSizeUser for $name<DM> {
// We use paste to append `U` to the digest size to match a const defined in
// digest
paste::paste! {
Expand All @@ -508,7 +501,7 @@ macro_rules! impl_sha {
}

#[cfg(feature = "digest")]
impl<'a> digest::Update for $name<'a, crate::Blocking> {
impl digest::Update for $name<crate::Blocking> {
fn update(&mut self, data: &[u8]) {
let mut remaining = data.as_ref();
while remaining.len() > 0 {
Expand All @@ -518,7 +511,7 @@ macro_rules! impl_sha {
}

#[cfg(feature = "digest")]
impl<'a> digest::FixedOutput for $name<'a, crate::Blocking> {
impl digest::FixedOutput for $name<crate::Blocking> {
fn finalize_into(mut self, out: &mut digest::Output<Self>) {
nb::block!(self.finish(out)).unwrap()
}
Expand Down
25 changes: 8 additions & 17 deletions hil-test/tests/sha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ mod tests {

#[test]
fn test_sha_1() {
let peripherals = Peripherals::take();
let mut sha = Sha1::new(peripherals.SHA);
let mut sha = Sha1::new();

let source_data = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".as_bytes();
let mut remaining = source_data;
Expand Down Expand Up @@ -90,8 +89,7 @@ mod tests {
#[test]
#[cfg(not(feature = "esp32"))]
fn test_sha_224() {
let peripherals = Peripherals::take();
let mut sha = esp_hal::sha::Sha224::new(peripherals.SHA);
let mut sha = esp_hal::sha::Sha224::new();

let source_data = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".as_bytes();
let mut remaining = source_data;
Expand Down Expand Up @@ -128,8 +126,7 @@ mod tests {

#[test]
fn test_sha_256() {
let peripherals = Peripherals::take();
let mut sha = Sha256::new(peripherals.SHA);
let mut sha = Sha256::new();

let source_data = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".as_bytes();
let mut remaining = source_data;
Expand Down Expand Up @@ -168,8 +165,7 @@ mod tests {
#[test]
#[cfg(any(feature = "esp32", feature = "esp32s2", feature = "esp32s3"))]
fn test_sha_384() {
let peripherals = Peripherals::take();
let mut sha = esp_hal::sha::Sha384::new(peripherals.SHA);
let mut sha = esp_hal::sha::Sha384::new();

let source_data = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".as_bytes();
let mut remaining = source_data;
Expand Down Expand Up @@ -210,8 +206,7 @@ mod tests {
#[test]
#[cfg(any(feature = "esp32", feature = "esp32s2", feature = "esp32s3"))]
fn test_sha_512() {
let peripherals = Peripherals::take();
let mut sha = esp_hal::sha::Sha512::new(peripherals.SHA);
let mut sha = esp_hal::sha::Sha512::new();

let source_data = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".as_bytes();
let mut remaining = source_data;
Expand Down Expand Up @@ -253,8 +248,7 @@ mod tests {
#[test]
#[cfg(any(feature = "esp32s2", feature = "esp32s3"))]
fn test_sha_512_224() {
let peripherals = Peripherals::take();
let mut sha = esp_hal::sha::Sha512_224::new(peripherals.SHA);
let mut sha = esp_hal::sha::Sha512_224::new();

let source_data = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".as_bytes();
let mut remaining = source_data;
Expand Down Expand Up @@ -293,8 +287,7 @@ mod tests {
#[test]
#[cfg(any(feature = "esp32s2", feature = "esp32s3"))]
fn test_sha_512_256() {
let peripherals = Peripherals::take();
let mut sha = esp_hal::sha::Sha512_256::new(peripherals.SHA);
let mut sha = esp_hal::sha::Sha512_256::new();

let source_data = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".as_bytes();
let mut remaining = source_data;
Expand Down Expand Up @@ -618,9 +611,7 @@ mod tests {
/// A simple test using [esp_hal::sha::Sha] trait to test hashing for an
/// algorithm against a specific size. This will compare the result with a
/// software implementation and return false if there's a mismatch
fn test_for_size<'a, D: Digest + Default + Sha<'a, esp_hal::Blocking>, const N: usize>(
size: usize,
) {
fn test_for_size<D: Digest + Default + Sha<esp_hal::Blocking>, const N: usize>(size: usize) {
let source_data = unsafe { core::slice::from_raw_parts(CHAR_ARRAY.as_ptr(), size) };
let mut remaining = source_data;
let mut hasher = D::default();
Expand Down

0 comments on commit 83236d5

Please sign in to comment.