From 34b10322e27913441d73aacc08eaf856ce88d896 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Mon, 23 Mar 2020 00:00:59 -0500 Subject: [PATCH] Add `DnsNameInput` to make constructing `DnsName` more convenient. --- src/name.rs | 70 +++++++++++++++++++++++++++++++---------- tests/dns_name_tests.rs | 31 +++++++++--------- 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/src/name.rs b/src/name.rs index e9d69d50..d88926d4 100644 --- a/src/name.rs +++ b/src/name.rs @@ -56,32 +56,70 @@ impl core::fmt::Display for InvalidDnsNameError { #[cfg(feature = "std")] impl ::std::error::Error for InvalidDnsNameError {} +pub trait DnsNameInput: AsRef<[u8]> + Sized { + type Storage: AsRef<[u8]>; + fn into_storage(self) -> Self::Storage; +} + +#[cfg(feature = "std")] +impl DnsNameInput for String { + type Storage = Box<[u8]>; + fn into_storage(self) -> Self::Storage { + self.into_boxed_str().into() + } +} + +#[cfg(feature = "std")] +impl DnsNameInput for Box<[u8]> { + type Storage = Box<[u8]>; + fn into_storage(self) -> Self::Storage { + self + } +} + +#[cfg(feature = "std")] +impl DnsNameInput for Vec { + type Storage = Box<[u8]>; + fn into_storage(self) -> Self::Storage { + self.into() + } +} + +impl<'a> DnsNameInput for &'a str { + type Storage = &'a [u8]; + fn into_storage(self) -> Self::Storage { + self.as_ref() + } +} + +impl<'a> DnsNameInput for &'a [u8] { + type Storage = &'a [u8]; + fn into_storage(self) -> Self::Storage { + self.as_ref() + } +} + impl DnsName where B: AsRef<[u8]>, { - /// TODO: docs - pub fn try_from_punycode_str<'a>(dns_name: &'a str) -> Result - where - B: From<&'a [u8]>, - { - Self::try_from_punycode(dns_name.as_ref()) - } - /// Constructs a `DnsName` from the given input if the input is a /// syntactically-valid DNS name. - pub fn try_from_punycode(dns_name: A) -> Result - where - A: AsRef<[u8]>, - A: Into, - { - if !is_valid_reference_dns_id(untrusted::Input::from(dns_name.as_ref())) { + pub fn try_from_punycode( + input: impl DnsNameInput, + ) -> Result { + if !is_valid_reference_dns_id(untrusted::Input::from(input.as_ref())) { return Err(InvalidDnsNameError); } - Ok(Self(dns_name.into())) + Ok(Self(input.into_storage())) } +} +impl DnsName +where + B: AsRef<[u8]>, +{ /// Borrows any `DnsName` as a `DnsName<&[u8]>`. /// /// Use `DnsName<&[u8]>` when you don't *need* to be generic over the @@ -93,7 +131,7 @@ where /// TODO: #[cfg(feature = "std")] - pub fn to_owned(&self) -> DnsName> { + pub fn into_owned(self) -> DnsName> { DnsName(Box::from(self.0.as_ref())) } diff --git a/tests/dns_name_tests.rs b/tests/dns_name_tests.rs index 45a5ebf0..636bfb00 100644 --- a/tests/dns_name_tests.rs +++ b/tests/dns_name_tests.rs @@ -1,6 +1,6 @@ // Copyright 2014-2017 Brian Smith. -use webpki::DnsName; +use webpki::{DnsName, DnsNameRef}; // (name, is_valid) static DNS_NAME_VALIDITY: &[(&[u8], bool)] = &[ @@ -422,7 +422,7 @@ const DNS_NAME_LOWERCASE_TEST_CASES: &[(&str, &str)] = &[ #[test] fn test_dns_name_ascii_lowercase_chars() { for (expected_lowercase, input) in DNS_NAME_LOWERCASE_TEST_CASES { - let dns_name: DnsName<&str> = DnsName::try_from_punycode(*input).unwrap(); + let dns_name: DnsNameRef = DnsName::try_from_punycode(*input).unwrap(); let actual_lowercase = dns_name.punycode_lowercase_bytes(); assert_eq!(expected_lowercase.len(), actual_lowercase.len()); @@ -438,7 +438,7 @@ fn test_dns_name_ascii_lowercase_chars() { #[test] fn test_dns_name_fmt() { for (expected_lowercase, input) in DNS_NAME_LOWERCASE_TEST_CASES { - let dns_name: DnsName<&str> = DnsName::try_from_punycode(*input).unwrap(); + let dns_name: DnsNameRef = DnsName::try_from_punycode(*input).unwrap(); // Test `Display` implementation. assert_eq!(*expected_lowercase, format!("{}", dns_name)); @@ -464,8 +464,8 @@ fn test_dns_name_eq_different_len() { ]; for (a, b) in NOT_EQUAL { - let a: DnsName<&str> = DnsName::try_from_punycode(*a).unwrap(); - let b: DnsName<&str> = DnsName::try_from_punycode(*b).unwrap(); + let a: DnsNameRef = DnsName::try_from_punycode(*a).unwrap(); + let b: DnsNameRef = DnsName::try_from_punycode(*b).unwrap(); assert_ne!(a, b) } } @@ -474,8 +474,8 @@ fn test_dns_name_eq_different_len() { #[test] fn test_dns_name_eq_case() { for (expected_lowercase, input) in DNS_NAME_LOWERCASE_TEST_CASES { - let a: DnsName<&str> = DnsName::try_from_punycode(*expected_lowercase).unwrap(); - let b: DnsName<&str> = DnsName::try_from_punycode(*input).unwrap(); + let a: DnsNameRef = DnsName::try_from_punycode(*expected_lowercase).unwrap(); + let b: DnsNameRef = DnsName::try_from_punycode(*input).unwrap(); assert_eq!(a, b); } } @@ -484,22 +484,19 @@ fn test_dns_name_eq_case() { #[cfg(feature = "std")] #[test] fn test_dns_name_eq_various_types() { - for (expected_lowercase, input) in DNS_NAME_LOWERCASE_TEST_CASES { - let a: DnsName<&str> = DnsName::try_from_punycode(*expected_lowercase).unwrap(); - let b: DnsName = DnsName::try_from_punycode(*input).unwrap(); - assert_eq!(a, b); - } + use webpki::DnsNameBox; for (expected_lowercase, input) in DNS_NAME_LOWERCASE_TEST_CASES { - let a: DnsName = DnsName::try_from_punycode(*expected_lowercase).unwrap(); - let b: DnsName<&[u8]> = DnsName::try_from_punycode(input.as_ref()).unwrap(); + let a: DnsNameRef = DnsName::try_from_punycode(*expected_lowercase).unwrap(); + let b: DnsNameBox = DnsName::try_from_punycode(*input).unwrap().into_owned(); assert_eq!(a, b); } for (expected_lowercase, input) in DNS_NAME_LOWERCASE_TEST_CASES { - let a: DnsName> = - DnsName::try_from_punycode(expected_lowercase.as_ref()).unwrap(); - let b: DnsName<&str> = DnsName::try_from_punycode(*input).unwrap(); + let a: DnsNameBox = DnsName::try_from_punycode(*expected_lowercase) + .unwrap() + .into_owned(); + let b: DnsNameRef = DnsName::try_from_punycode(*input).unwrap(); assert_eq!(a, b); } }