diff --git a/crates/uv-configuration/src/trusted_host.rs b/crates/uv-configuration/src/trusted_host.rs index eff4a52aaa55a..26e3e8b108367 100644 --- a/crates/uv-configuration/src/trusted_host.rs +++ b/crates/uv-configuration/src/trusted_host.rs @@ -3,19 +3,32 @@ use url::Url; /// A trusted host, which could be a host or a host-port pair. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum TrustedHost { - Host(String), - HostPort(String, u16), +pub struct TrustedHost { + scheme: Option, + host: String, + port: Option, } impl TrustedHost { + /// Returns `true` if the [`Url`] matches this trusted host. pub fn matches(&self, url: &Url) -> bool { - match self { - Self::Host(host) => url.host_str() == Some(host.as_str()), - Self::HostPort(host, port) => { - url.host_str() == Some(host.as_str()) && url.port() == Some(*port) - } + if self + .scheme + .as_ref() + .map_or(true, |scheme| scheme != url.scheme()) + { + return false; + } + + if self.port.map_or(true, |port| url.port() != Some(port)) { + return false; } + + if Some(self.host.as_ref()) != url.host_str() { + return false; + } + + true } } @@ -31,27 +44,32 @@ impl std::str::FromStr for TrustedHost { type Err = TrustedHostError; fn from_str(s: &str) -> Result { - // Strip `http://` or `https://`. - let s = s - .strip_prefix("https://") - .unwrap_or_else(|| s.strip_prefix("http://").unwrap_or(s)); + // Detect scheme. + let (scheme, s) = if let Some(s) = s.strip_prefix("https://") { + (Some("https".to_string()), s) + } else if let Some(s) = s.strip_prefix("http://") { + (Some("http".to_string()), s) + } else { + (None, s) + }; - // Split into host and scheme. let mut parts = s.splitn(2, ':'); + + // Detect host. let host = parts .next() .and_then(|host| host.split('/').next()) + .map(ToString::to_string) .ok_or_else(|| TrustedHostError::MissingHost(s.to_string()))?; + + // Detect port. let port = parts .next() .map(str::parse) .transpose() .map_err(|_| TrustedHostError::InvalidPort(s.to_string()))?; - match port { - Some(port) => Ok(TrustedHost::HostPort(host.to_string(), port)), - None => Ok(TrustedHost::Host(host.to_string())), - } + Ok(Self { scheme, host, port }) } } @@ -80,24 +98,40 @@ mod tests { fn parse() { assert_eq!( "example.com".parse::().unwrap(), - super::TrustedHost::Host("example.com".to_string()) + super::TrustedHost { + scheme: None, + host: "example.com".to_string(), + port: None + } ); assert_eq!( "example.com:8080".parse::().unwrap(), - super::TrustedHost::HostPort("example.com".to_string(), 8080) + super::TrustedHost { + scheme: None, + host: "example.com".to_string(), + port: Some(8080) + } ); assert_eq!( "https://example.com".parse::().unwrap(), - super::TrustedHost::Host("example.com".to_string()) + super::TrustedHost { + scheme: Some("https".to_string()), + host: "example.com".to_string(), + port: None + } ); assert_eq!( "https://example.com/hello/world" .parse::() .unwrap(), - super::TrustedHost::Host("example.com".to_string()) + super::TrustedHost { + scheme: Some("https".to_string()), + host: "example.com".to_string(), + port: None + } ); } }