diff --git a/der/src/asn1/context_specific.rs b/der/src/asn1/context_specific.rs index dc6b43aa8..79d13ad71 100644 --- a/der/src/asn1/context_specific.rs +++ b/der/src/asn1/context_specific.rs @@ -89,9 +89,7 @@ impl ContextSpecific { F: FnOnce(&mut R) -> Result, E: From, { - while let Some(octet) = reader.peek_byte() { - let tag = Tag::try_from(octet)?; - + while let Some(tag) = Tag::peek_optional(reader)? { if !tag.is_context_specific() || (tag.number() > tag_number) { break; } else if tag.number() == tag_number { diff --git a/der/src/asn1/optional.rs b/der/src/asn1/optional.rs index 26e24d683..5ad8a210a 100644 --- a/der/src/asn1/optional.rs +++ b/der/src/asn1/optional.rs @@ -10,8 +10,8 @@ where type Error = T::Error; fn decode>(reader: &mut R) -> Result, Self::Error> { - if let Some(byte) = reader.peek_byte() { - if T::can_decode(Tag::try_from(byte)?) { + if let Some(tag) = Tag::peek_optional(reader)? { + if T::can_decode(tag) { return T::decode(reader).map(Some); } } diff --git a/der/src/encode.rs b/der/src/encode.rs index eba1f262d..c1c69a1d0 100644 --- a/der/src/encode.rs +++ b/der/src/encode.rs @@ -73,7 +73,7 @@ where { /// Compute the length of this value in bytes when encoded as ASN.1 DER. fn encoded_len(&self) -> Result { - self.value_len().and_then(|len| len.for_tlv()) + self.value_len().and_then(|len| len.for_tlv(self.tag())) } /// Encode this value as ASN.1 DER using the provided [`Writer`]. diff --git a/der/src/header.rs b/der/src/header.rs index 3d2dcd568..4a2bfbf20 100644 --- a/der/src/header.rs +++ b/der/src/header.rs @@ -15,7 +15,7 @@ pub struct Header { impl Header { /// Maximum number of DER octets a header can be in this crate. - pub(crate) const MAX_SIZE: usize = 1 + Length::MAX_SIZE; + pub(crate) const MAX_SIZE: usize = Tag::MAX_SIZE + Length::MAX_SIZE; /// Create a new [`Header`] from a [`Tag`] and a specified length. /// diff --git a/der/src/length.rs b/der/src/length.rs index 0a992366b..a19058982 100644 --- a/der/src/length.rs +++ b/der/src/length.rs @@ -1,6 +1,6 @@ //! Length calculations for encoded ASN.1 DER values -use crate::{Decode, DerOrd, Encode, Error, ErrorKind, Reader, Result, SliceWriter, Writer}; +use crate::{Decode, DerOrd, Encode, Error, ErrorKind, Reader, Result, SliceWriter, Tag, Writer}; use core::{ cmp::Ordering, fmt, @@ -51,8 +51,8 @@ impl Length { /// Get the length of DER Tag-Length-Value (TLV) encoded data if `self` /// is the length of the inner "value" portion of the message. - pub fn for_tlv(self) -> Result { - Self::ONE + self.encoded_len()? + self + pub fn for_tlv(self, tag: Tag) -> Result { + tag.encoded_len()? + self.encoded_len()? + self } /// Perform saturating addition of two lengths. diff --git a/der/src/reader.rs b/der/src/reader.rs index fe77d0359..aecea7a7b 100644 --- a/der/src/reader.rs +++ b/der/src/reader.rs @@ -118,10 +118,7 @@ pub trait Reader<'r>: Sized { /// Peek at the next byte in the reader. #[deprecated(since = "0.8.0-rc.1", note = "use `Tag::peek` instead")] fn peek_tag(&self) -> Result { - match self.peek_byte() { - Some(byte) => byte.try_into(), - None => Err(Error::incomplete(self.input_len())), - } + Tag::peek(self) } /// Read a single byte. diff --git a/der/src/tag.rs b/der/src/tag.rs index d91197a62..cbabe6a4e 100644 --- a/der/src/tag.rs +++ b/der/src/tag.rs @@ -143,14 +143,38 @@ pub enum Tag { } impl Tag { - /// Peek at the next byte in the reader and attempt to decode it as a [`Tag`] value. + /// Maximum number of octets in a DER encoding of a [`Tag`] using the + /// rules implemented by this crate. + pub(crate) const MAX_SIZE: usize = 4; + + /// Peek at the next bytes in the reader and attempt to decode it as a [`Tag`] value. /// /// Does not modify the reader's state. pub fn peek<'a>(reader: &impl Reader<'a>) -> Result { - match reader.peek_byte() { - Some(byte) => byte.try_into(), - None => Err(Error::incomplete(reader.input_len())), + Self::peek_optional(reader)?.ok_or_else(|| Error::incomplete(reader.input_len())) + } + + pub(crate) fn peek_optional<'a>(reader: &impl Reader<'a>) -> Result> { + let mut buf = [0u8; Self::MAX_SIZE]; + + if reader.peek_into(&mut buf[0..1]).is_err() { + return Ok(None); + }; + + if let Ok(tag) = Self::from_der(&buf[0..1]) { + return Ok(Some(tag)); } + + for i in 2..Self::MAX_SIZE { + let slice = &mut buf[0..i]; + if reader.peek_into(slice).is_ok() { + if let Ok(tag) = Self::from_der(slice) { + return Ok(Some(tag)); + } + } + } + + Some(Self::from_der(&buf)).transpose() } /// Assert that this [`Tag`] matches the provided expected tag. @@ -174,14 +198,45 @@ impl Tag { } } - /// Get the [`TagNumber`] (lower 6-bits) for this tag. + /// Get the [`TagNumber`] for this tag. pub fn number(self) -> TagNumber { - TagNumber(self.octet() & TagNumber::MASK) + match self { + Tag::Boolean => TagNumber::N1, + Tag::Integer => TagNumber::N2, + Tag::BitString => TagNumber::N3, + Tag::OctetString => TagNumber::N4, + Tag::Null => TagNumber::N5, + Tag::ObjectIdentifier => TagNumber::N6, + Tag::Real => TagNumber::N9, + Tag::Enumerated => TagNumber::N10, + Tag::Utf8String => TagNumber::N12, + Tag::Sequence => TagNumber::N16, + Tag::Set => TagNumber::N17, + Tag::NumericString => TagNumber::N18, + Tag::PrintableString => TagNumber::N19, + Tag::TeletexString => TagNumber::N20, + Tag::VideotexString => TagNumber::N21, + Tag::Ia5String => TagNumber::N22, + Tag::UtcTime => TagNumber::N23, + Tag::GeneralizedTime => TagNumber::N24, + Tag::VisibleString => TagNumber::N26, + Tag::GeneralString => TagNumber::N27, + Tag::BmpString => TagNumber::N30, + Tag::Application { number, .. } => number, + Tag::ContextSpecific { number, .. } => number, + Tag::Private { number, .. } => number, + } } /// Does this tag represent a constructed (as opposed to primitive) field? pub fn is_constructed(self) -> bool { - self.octet() & CONSTRUCTED_FLAG != 0 + match self { + Tag::Sequence | Tag::Set => true, + Tag::Application { constructed, .. } + | Tag::ContextSpecific { constructed, .. } + | Tag::Private { constructed, .. } => constructed, + _ => false, + } } /// Is this an application tag? @@ -204,45 +259,6 @@ impl Tag { self.class() == Class::Universal } - /// Get the octet encoding for this [`Tag`]. - pub fn octet(self) -> u8 { - match self { - Tag::Boolean => 0x01, - Tag::Integer => 0x02, - Tag::BitString => 0x03, - Tag::OctetString => 0x04, - Tag::Null => 0x05, - Tag::ObjectIdentifier => 0x06, - Tag::Real => 0x09, - Tag::Enumerated => 0x0A, - Tag::Utf8String => 0x0C, - Tag::Sequence => 0x10 | CONSTRUCTED_FLAG, - Tag::Set => 0x11 | CONSTRUCTED_FLAG, - Tag::NumericString => 0x12, - Tag::PrintableString => 0x13, - Tag::TeletexString => 0x14, - Tag::VideotexString => 0x15, - Tag::Ia5String => 0x16, - Tag::UtcTime => 0x17, - Tag::GeneralizedTime => 0x18, - Tag::VisibleString => 0x1A, - Tag::GeneralString => 0x1B, - Tag::BmpString => 0x1E, - Tag::Application { - constructed, - number, - } - | Tag::ContextSpecific { - constructed, - number, - } - | Tag::Private { - constructed, - number, - } => self.class().octet(constructed, number), - } - } - /// Create an [`Error`] for an invalid [`Length`]. pub fn length_error(self) -> Error { ErrorKind::Length { tag: self }.into() @@ -271,85 +287,146 @@ impl Tag { } } -impl TryFrom for Tag { +impl<'a> Decode<'a> for Tag { type Error = Error; - fn try_from(byte: u8) -> Result { - let constructed = byte & CONSTRUCTED_FLAG != 0; - let number = TagNumber::try_from(byte & TagNumber::MASK)?; - - match byte { - 0x01 => Ok(Tag::Boolean), - 0x02 => Ok(Tag::Integer), - 0x03 => Ok(Tag::BitString), - 0x04 => Ok(Tag::OctetString), - 0x05 => Ok(Tag::Null), - 0x06 => Ok(Tag::ObjectIdentifier), - 0x09 => Ok(Tag::Real), - 0x0A => Ok(Tag::Enumerated), - 0x0C => Ok(Tag::Utf8String), - 0x12 => Ok(Tag::NumericString), - 0x13 => Ok(Tag::PrintableString), - 0x14 => Ok(Tag::TeletexString), - 0x15 => Ok(Tag::VideotexString), - 0x16 => Ok(Tag::Ia5String), - 0x17 => Ok(Tag::UtcTime), - 0x18 => Ok(Tag::GeneralizedTime), - 0x1A => Ok(Tag::VisibleString), - 0x1B => Ok(Tag::GeneralString), - 0x1E => Ok(Tag::BmpString), - 0x30 => Ok(Tag::Sequence), // constructed - 0x31 => Ok(Tag::Set), // constructed - 0x40..=0x7E => Ok(Tag::Application { - constructed, - number, - }), - 0x80..=0xBE => Ok(Tag::ContextSpecific { - constructed, - number, - }), - 0xC0..=0xFE => Ok(Tag::Private { - constructed, - number, - }), - _ => Err(ErrorKind::TagUnknown { byte }.into()), - } - } -} + fn decode>(reader: &mut R) -> Result { + let first_byte = reader.read_byte()?; + + let tag = match first_byte { + 0x01 => Tag::Boolean, + 0x02 => Tag::Integer, + 0x03 => Tag::BitString, + 0x04 => Tag::OctetString, + 0x05 => Tag::Null, + 0x06 => Tag::ObjectIdentifier, + 0x09 => Tag::Real, + 0x0A => Tag::Enumerated, + 0x0C => Tag::Utf8String, + 0x12 => Tag::NumericString, + 0x13 => Tag::PrintableString, + 0x14 => Tag::TeletexString, + 0x15 => Tag::VideotexString, + 0x16 => Tag::Ia5String, + 0x17 => Tag::UtcTime, + 0x18 => Tag::GeneralizedTime, + 0x1A => Tag::VisibleString, + 0x1B => Tag::GeneralString, + 0x1E => Tag::BmpString, + 0x30 => Tag::Sequence, // constructed + 0x31 => Tag::Set, // constructed + 0x40..=0x7F => { + let (constructed, number) = parse_parts(first_byte, reader)?; + + Tag::Application { + constructed, + number, + } + } + 0x80..=0xBF => { + let (constructed, number) = parse_parts(first_byte, reader)?; + + Tag::ContextSpecific { + constructed, + number, + } + } + 0xC0..=0xFF => { + let (constructed, number) = parse_parts(first_byte, reader)?; -impl From for u8 { - fn from(tag: Tag) -> u8 { - tag.octet() + Tag::Private { + constructed, + number, + } + } + byte => return Err(ErrorKind::TagUnknown { byte }.into()), + }; + + Ok(tag) } } -impl From<&Tag> for u8 { - fn from(tag: &Tag) -> u8 { - u8::from(*tag) +fn parse_parts<'a, R: Reader<'a>>(first_byte: u8, reader: &mut R) -> Result<(bool, TagNumber)> { + let constructed = first_byte & CONSTRUCTED_FLAG != 0; + let first_number_part = first_byte & TagNumber::MASK; + + if first_number_part != TagNumber::MASK { + return Ok((constructed, TagNumber::new(first_number_part.into()))); } -} -impl<'a> Decode<'a> for Tag { - type Error = Error; + let mut multi_byte_tag_number = 0; - fn decode>(reader: &mut R) -> Result { - reader.read_byte().and_then(Self::try_from) + for _ in 0..Tag::MAX_SIZE - 2 { + multi_byte_tag_number <<= 7; + + let byte = reader.read_byte()?; + multi_byte_tag_number |= u16::from(byte & 0x7F); + + if byte & 0x80 == 0 { + return Ok((constructed, TagNumber::new(multi_byte_tag_number))); + } + } + + let byte = reader.read_byte()?; + if multi_byte_tag_number > u16::MAX >> 7 || byte & 0x80 != 0 { + return Err(ErrorKind::TagNumberInvalid.into()); } + multi_byte_tag_number |= u16::from(byte & 0x7F); + + Ok((constructed, TagNumber::new(multi_byte_tag_number))) } impl Encode for Tag { + #[allow(clippy::cast_possible_truncation)] fn encoded_len(&self) -> Result { - Ok(Length::ONE) + let number = self.number().value(); + + let length = if number <= 30 { + Length::ONE + } else { + Length::new(number.ilog2() as u16 / 7 + 2) + }; + + Ok(length) } + #[allow(clippy::cast_possible_truncation)] fn encode(&self, writer: &mut impl Writer) -> Result<()> { - writer.write_byte(self.into()) + let mut first_byte = self.class() as u8 | u8::from(self.is_constructed()) << 5; + + let number = self.number().value(); + + if number <= 30 { + first_byte |= number as u8; + writer.write_byte(first_byte)?; + } else { + first_byte |= 0x1F; + writer.write_byte(first_byte)?; + + let extra_bytes = number.ilog2() as u16 / 7 + 1; + + for shift in (0..extra_bytes).rev() { + let mut byte = (number >> (shift * 7)) as u8 & 0x7f; + + if shift != 0 { + byte |= 0x80; + } + + writer.write_byte(byte)?; + } + } + + Ok(()) } } impl DerOrd for Tag { fn der_cmp(&self, other: &Self) -> Result { - Ok(self.octet().cmp(&other.octet())) + Ok(self + .class() + .cmp(&other.class()) + .then_with(|| self.is_constructed().cmp(&other.is_constructed())) + .then_with(|| self.number().cmp(&other.number()))) } } @@ -412,14 +489,16 @@ impl fmt::Display for Tag { impl fmt::Debug for Tag { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Tag(0x{:02x}: {})", u8::from(*self), self) + write!(f, "Tag(0x{:02x}: {})", self.number().value(), self) } } #[cfg(test)] mod tests { + use hex_literal::hex; + use super::{Class, Tag, TagNumber}; - use crate::{Length, Reader, SliceReader}; + use crate::{Decode, ErrorKind, Length, Reader, SliceReader}; #[test] fn tag_class() { @@ -484,4 +563,23 @@ mod tests { assert_eq!(Tag::peek(&reader).unwrap(), Tag::Integer); assert_eq!(reader.position(), Length::ZERO); // Position unchanged } + + #[test] + fn decoding() { + // valid tag number but must be in short form + assert_eq!( + ErrorKind::TagNumberInvalid, + Tag::from_der(&hex!("FF03")).unwrap_err().kind() + ); + // universal tag with long form + assert_eq!( + ErrorKind::TagNumberInvalid, + Tag::from_der(&hex!("1FFF")).unwrap_err().kind() + ); + // leading zeros in long form + assert_eq!( + ErrorKind::TagNumberInvalid, + Tag::from_der(&hex!("5F8020")).unwrap_err().kind() + ); + } } diff --git a/der/src/tag/class.rs b/der/src/tag/class.rs index ffb2a1e75..2a3bba533 100644 --- a/der/src/tag/class.rs +++ b/der/src/tag/class.rs @@ -1,6 +1,5 @@ //! Class of an ASN.1 tag. -use super::{TagNumber, CONSTRUCTED_FLAG}; use core::fmt; /// Class of an ASN.1 tag. @@ -30,14 +29,6 @@ pub enum Class { Private = 0b11000000, } -impl Class { - /// Compute the identifier octet for a tag number of this class. - #[allow(clippy::arithmetic_side_effects)] - pub(super) fn octet(self, constructed: bool, number: TagNumber) -> u8 { - self as u8 | number.value() | (u8::from(constructed) * CONSTRUCTED_FLAG) - } -} - impl fmt::Display for Class { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { diff --git a/der/src/tag/number.rs b/der/src/tag/number.rs index dfff9a961..9b261aedc 100644 --- a/der/src/tag/number.rs +++ b/der/src/tag/number.rs @@ -1,25 +1,25 @@ //! ASN.1 tag numbers use super::Tag; -use crate::{Error, ErrorKind, Result}; use core::fmt; /// ASN.1 tag numbers (i.e. lower 5 bits of a [`Tag`]). /// /// From X.690 Section 8.1.2.2: /// +/// Tag numbers ranging from zero to 30 (inclusive) can be represented as a +/// single identifier octet. +/// /// > bits 5 to 1 shall encode the number of the tag as a binary integer with /// > bit 5 as the most significant bit. /// -/// This library supports tag numbers ranging from zero to 30 (inclusive), -/// which can be represented as a single identifier octet. -/// /// Section 8.1.2.4 describes how to support multi-byte tag numbers, which are -/// encoded by using a leading tag number of 31 (`0b11111`). This library -/// deliberately does not support this: tag numbers greater than 30 are -/// disallowed. +/// encoded by using a leading tag number of 31 (`0b11111`). +/// +/// This library supports tag numbers with 16 bit values +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] -pub struct TagNumber(pub(super) u8); +pub struct TagNumber(pub u16); impl TagNumber { /// Tag number `0` @@ -118,20 +118,9 @@ impl TagNumber { /// Mask value used to obtain the tag number from a tag octet. pub(super) const MASK: u8 = 0b11111; - /// Maximum tag number supported (inclusive). - const MAX: u8 = 30; - /// Create a new tag number (const-friendly). - /// - /// Panics if the tag number is greater than `30`. - /// For a fallible conversion, use [`TryFrom`] instead. - pub const fn new(byte: u8) -> Self { - #[allow(clippy::panic)] - if byte > Self::MAX { - panic!("tag number out of range"); - } - - Self(byte) + pub const fn new(number: u16) -> Self { + Self(number) } /// Create an `APPLICATION` tag with this tag number. @@ -159,43 +148,13 @@ impl TagNumber { } /// Get the inner value. - pub fn value(self) -> u8 { + pub fn value(self) -> u16 { self.0 } } -impl TryFrom for TagNumber { - type Error = Error; - - fn try_from(byte: u8) -> Result { - match byte { - 0..=Self::MAX => Ok(Self(byte)), - _ => Err(ErrorKind::TagNumberInvalid.into()), - } - } -} - -impl From for u8 { - fn from(tag_number: TagNumber) -> u8 { - tag_number.0 - } -} - impl fmt::Display for TagNumber { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } - -// Implement by hand because the derive would create invalid values. -// Use the constructor to create a valid value. -#[cfg(feature = "arbitrary")] -impl<'a> arbitrary::Arbitrary<'a> for TagNumber { - fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - Ok(Self::new(u.int_in_range(0..=Self::MAX)?)) - } - - fn size_hint(depth: usize) -> (usize, Option) { - u8::size_hint(depth) - } -} diff --git a/pkcs1/src/version.rs b/pkcs1/src/version.rs index f880253f2..fcbbcf383 100644 --- a/pkcs1/src/version.rs +++ b/pkcs1/src/version.rs @@ -60,7 +60,7 @@ impl<'a> Decode<'a> for Version { impl Encode for Version { fn encoded_len(&self) -> der::Result { - der::Length::ONE.for_tlv() + der::Length::ONE.for_tlv(Self::TAG) } fn encode(&self, writer: &mut impl Writer) -> der::Result<()> { diff --git a/pkcs8/src/version.rs b/pkcs8/src/version.rs index d5a3c5747..352e9728f 100644 --- a/pkcs8/src/version.rs +++ b/pkcs8/src/version.rs @@ -35,7 +35,7 @@ impl<'a> Decode<'a> for Version { impl Encode for Version { fn encoded_len(&self) -> der::Result { - der::Length::from(1u8).for_tlv() + der::Length::from(1u8).for_tlv(Self::TAG) } fn encode(&self, writer: &mut impl Writer) -> der::Result<()> {