diff --git a/src/hash/md5.rs b/src/hash/md5.rs index 66af360..c342b26 100644 --- a/src/hash/md5.rs +++ b/src/hash/md5.rs @@ -74,6 +74,7 @@ pub struct Md5 { buffer: [u8; Self::BLOCK_LEN], state: [u32; 4], len: usize, // in bytes. + offset: usize, } impl Md5 { @@ -86,88 +87,69 @@ impl Md5 { buffer: [0u8; Self::BLOCK_LEN], state: INITIAL_STATE, len: 0, + offset: 0, } } pub fn update(&mut self, data: &[u8]) { - // TODO: - // Unlike Sha1 and Sha2, the length value in MD5 is defined as - // the length of the message mod 2^64 - ie: integer overflow is OK. - if data.len() == 0 { - return (); - } - - let mut n = self.len % Self::BLOCK_LEN; - if n != 0 { - let mut i = 0usize; - loop { - if n == 64 || i >= data.len() { - break; - } - self.buffer[n] = data[i]; - n += 1; + let mut i = 0usize; + while i < data.len() { + if self.offset < Self::BLOCK_LEN { + self.buffer[self.offset] = data[i]; + self.offset += 1; i += 1; - self.len += 1; } - - if self.len % Self::BLOCK_LEN != 0 { - return (); - } else { + + if self.offset == Self::BLOCK_LEN { transform(&mut self.state, &self.buffer); - - let data = &data[i..]; - if data.len() > 0 { - return self.update(data); - } - } - } - - if data.len() < 64 { - self.buffer[..data.len()].copy_from_slice(data); - self.len += data.len(); - } else if data.len() == 64 { - transform(&mut self.state, data); - self.len += 64; - } else if data.len() > 64 { - let blocks = data.len() / 64; - for i in 0..blocks { - transform(&mut self.state, &data[i*64..i*64+64]); - self.len += 64; - } - let data = &data[blocks*64..]; - if data.len() > 0 { - self.buffer[..data.len()].copy_from_slice(data); - self.len += data.len(); + self.offset = 0; + self.len += Self::BLOCK_LEN; } - } else { - unreachable!() } } pub fn finalize(mut self) -> [u8; Self::DIGEST_LEN] { - // last_block - let len_bits = u64::try_from(self.len).unwrap() * 8; - let n = self.len % Self::BLOCK_LEN; - if n == 0 { - let mut block = [0u8; 64]; - block[0] = 0x80; - block[56..].copy_from_slice(&len_bits.to_le_bytes()); - transform(&mut self.state, &block); - } else { - self.buffer[n] = 0x80; - for i in n+1..64 { - self.buffer[i] = 0; - } - if 64 - n - 1 >= 8 { - self.buffer[56..].copy_from_slice(&len_bits.to_le_bytes()); - transform(&mut self.state, &self.buffer); - } else { - transform(&mut self.state, &self.buffer); - let mut block = [0u8; 64]; - block[56..].copy_from_slice(&len_bits.to_le_bytes()); - transform(&mut self.state, &block); - } - } + let mlen = self.len + self.offset; // in bytes + let mlen_bits = mlen * 8; // in bits + + let plen_bits = 512 - (mlen_bits + 64 + 1) % 512 + 1; // pad len, in bits + assert_eq!(plen_bits % 8, 0); + let plen = plen_bits / 8; // pad len, in bytes + + // NOTE: MLEN + PLEN + MLEN_OCTETS (the length of the message before the padding bits were added) + let dlen = mlen + plen + 8; + + assert_eq!(dlen % Self::BLOCK_LEN, 0); + assert!(plen > 1); + + let mut padding_block: [u8; Self::BLOCK_LEN * 2] = [ + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + ]; + + let mlen_octets = u64::try_from(mlen_bits).unwrap(); + padding_block[plen..plen + 8].copy_from_slice(&mlen_octets.to_le_bytes()); + + let data = &padding_block[..plen + 8]; + self.update(data); + + assert_eq!(self.offset, 0); let mut output = [0u8; Self::DIGEST_LEN]; output[ 0.. 4].copy_from_slice(&self.state[0].to_le_bytes()); @@ -447,4 +429,17 @@ fn test_md5_long_message() { let msg = vec![b'a'; 1000_000]; let digest = [119, 7, 214, 174, 78, 2, 124, 112, 238, 162, 169, 53, 194, 41, 111, 33]; assert_eq!(Md5::oneshot(&msg), digest); -} \ No newline at end of file +} + +#[test] +fn test_md5_oneshot() { + let mut m1 = Md5::new(); + m1.update(&hex::decode("4b01a2d762fada9ede4d1034a13dc69c").unwrap()); + m1.update(&hex::decode("496d616b65746869735f4c6f6e6750617373506872617365466f725f7361666574795f323031395f30393238405f4021").unwrap()); + let h1 = m1.finalize(); + + let h2 = Md5::oneshot(&hex::decode("4b01a2d762fada9ede4d1034a13dc69c\ +496d616b65746869735f4c6f6e6750617373506872617365466f725f7361666574795f323031395f30393238405f4021").unwrap()); + + assert_eq!(h1, h2); +}