diff --git a/src/dns/server.rs b/src/dns/server.rs index 68aa78c694..561b55e538 100644 --- a/src/dns/server.rs +++ b/src/dns/server.rs @@ -21,7 +21,6 @@ use hickory_resolver::system_conf::read_system_conf; use hickory_server::ServerFuture; use hickory_server::authority::LookupError; use hickory_server::server::Request; -use itertools::Itertools; use once_cell::sync::Lazy; use rand::rng; use rand::seq::SliceRandom; @@ -47,9 +46,10 @@ use crate::drain::{DrainMode, DrainWatcher}; use crate::metrics::{DeferRecorder, IncrementRecorder, Recorder}; use crate::proxy::Error; use crate::state::DemandProxyState; -use crate::state::service::{IpFamily, Service}; +use crate::state::service::{IpFamily, Service, ServiceMatch}; use crate::state::workload::Workload; use crate::state::workload::address::Address; +use crate::strng::Strng; use crate::{config, dns}; const DEFAULT_TCP_REQUEST_TIMEOUT: u64 = 5; @@ -187,26 +187,6 @@ impl Server { } } -enum MatchReason<'a> { - Canonical(&'a Arc), - First(&'a Arc), - Namespace(&'a Arc), - PreferredNamespace(&'a Arc), - None, -} - -impl<'a> From> for Option<&'a Arc> { - fn from(value: MatchReason<'a>) -> Option<&'a Arc> { - match value { - MatchReason::Canonical(s) - | MatchReason::First(s) - | MatchReason::Namespace(s) - | MatchReason::PreferredNamespace(s) => Some(s), - MatchReason::None => None, - } - } -} - /// A DNS [Resolver] backed by the ztunnel [DemandProxyState]. struct Store { state: DemandProxyState, @@ -386,7 +366,7 @@ impl Store { let services: Vec> = state .services .get_by_host(&search_name_str) - .iter() + .into_iter() .flatten() // Remove things without a VIP, unless they are Kubernetes headless services. // This will trigger us to forward upstream. @@ -406,39 +386,18 @@ impl Store { }) // Get the service matching the client namespace. If no match exists, just // return the first service. - // .find_or_first(|service| service.namespace == client.namespace) - .cloned() .collect(); - let service: Option<&Arc> = services - .iter() - .fold_while(MatchReason::None, |r, s| { - if s.namespace == client.namespace { - itertools::FoldWhile::Done(MatchReason::Namespace(s)) - } else if s.canonical { - itertools::FoldWhile::Continue(MatchReason::Canonical(s)) - } else { - // TODO: deprecate preferred_service_namespace - // https://github.com/istio/ztunnel/issues/1709 - if let Some(preferred_namespace) = - self.prefered_service_namespace.as_ref() - && preferred_namespace.as_str() == s.namespace - && !matches!(r, MatchReason::Canonical(_)) - { - return itertools::FoldWhile::Continue( - MatchReason::PreferredNamespace(s), - ); - } - match r { - MatchReason::None => { - itertools::FoldWhile::Continue(MatchReason::First(s)) - } - _ => itertools::FoldWhile::Continue(r), - } - } - }) - .into_inner() + let preferred_namespace: Strng = self + .prefered_service_namespace + .as_deref() + .unwrap_or("") .into(); + let service: Option<&Arc> = ServiceMatch::find_best_match( + services.iter(), + Some(&client.namespace), + Some(&preferred_namespace), + ); // First, lookup the host as a service. if let Some(service) = service { @@ -1884,7 +1843,7 @@ mod tests { .unwrap() .iter() .map(|(_, addr)| *addr) - .collect_vec() + .collect() } fn kube_fqdn, S2: AsRef>(name: S1, ns: S2) -> String { diff --git a/src/proxy/inbound.rs b/src/proxy/inbound.rs index a4f6255eb4..f5a5748735 100644 --- a/src/proxy/inbound.rs +++ b/src/proxy/inbound.rs @@ -494,19 +494,10 @@ impl Inbound { local_workload: &Workload, hbone_host: &Strng, ) -> Result, Error> { - // Validate a service exists for the hostname - let services = state.read().find_service_by_hostname(hbone_host)?; - - services - .iter() - .max_by_key(|s| { - let is_local_namespace = s.namespace == local_workload.namespace; - match is_local_namespace { - true => 1, - false => 0, - } - }) - .cloned() + state + .read() + .services + .get_best_by_host(hbone_host, Some(&local_workload.namespace)) .ok_or_else(|| Error::NoHostname(hbone_host.to_string())) } diff --git a/src/state.rs b/src/state.rs index c04db06060..7eb7e2d97e 100644 --- a/src/state.rs +++ b/src/state.rs @@ -281,10 +281,14 @@ impl ProxyState { } /// Find services by hostname. - pub fn find_service_by_hostname(&self, hostname: &Strng) -> Result>, Error> { + pub fn find_service_by_hostname( + &self, + hostname: &Strng, + namespace: &Strng, + ) -> Result, Error> { // Hostnames for services are more common, so lookup service first and fallback to workload. self.services - .get_by_host(hostname) + .get_best_by_host(hostname, Some(namespace)) .ok_or_else(|| Error::NoHostname(hostname.to_string())) } diff --git a/src/state/service.rs b/src/state/service.rs index e621607a4b..8838fe59fb 100644 --- a/src/state/service.rs +++ b/src/state/service.rs @@ -371,6 +371,14 @@ impl ServiceStore { self.by_host.get(hostname).map(|v| v.to_vec()) } + // Returns the "best" [Srevice] matching the given hostname. + // If a namespace is provided, a Service from that namespace is preferred. + // Next, a Service marked `canonical` is prerferred. + pub fn get_best_by_host(&self, hostname: &Strng, ns: Option<&Strng>) -> Option> { + let services = self.get_by_host(hostname)?; + Some(ServiceMatch::find_best_match(services.iter(), ns, None)?.clone()) + } + pub fn get_by_workload(&self, workload: &Workload) -> Vec> { workload .services @@ -584,3 +592,64 @@ impl ServiceStore { self.staged_services.len() } } + +/// Represents the reason a service was matched during lookup. +/// Used with fold_while to implement priority-based service selection +/// with short-circuit on best match (namespace + primary hostname). +/// +/// Priority order (lower is better): Namespace > Canonical > First +pub enum ServiceMatch<'a> { + Canonical(&'a Arc), + Namespace(&'a Arc), + PreferredNamespace(&'a Arc), + First(&'a Arc), + None, +} + +impl<'a> From> for Option<&'a Arc> { + fn from(value: ServiceMatch<'a>) -> Option<&'a Arc> { + match value { + ServiceMatch::Canonical(s) + | ServiceMatch::First(s) + | ServiceMatch::Namespace(s) + | ServiceMatch::PreferredNamespace(s) => Some(s), + ServiceMatch::None => None, + } + } +} + +impl<'a> ServiceMatch<'a> { + /// Finds the best matching service from an iterator using fold_while. + /// Short-circuits on Namespace match - the best possible result. + pub fn find_best_match( + mut services: impl Iterator>, + client_ns: Option<&Strng>, + preferred_namespace: Option<&Strng>, + ) -> Option<&'a Arc> { + services + .fold_while(ServiceMatch::None, |r, s| { + if Some(&s.namespace) == client_ns { + itertools::FoldWhile::Done(ServiceMatch::Namespace(s)) + } else if s.canonical { + itertools::FoldWhile::Continue(ServiceMatch::Canonical(s)) + } else { + // TODO: deprecate preferred_service_namespace + // https://github.com/istio/ztunnel/issues/1709 + if let Some(preferred_namespace) = preferred_namespace + && preferred_namespace == &s.namespace + && !matches!(r, ServiceMatch::Canonical(_)) + { + return itertools::FoldWhile::Continue(ServiceMatch::PreferredNamespace(s)); + } + match r { + ServiceMatch::None => { + itertools::FoldWhile::Continue(ServiceMatch::First(s)) + } + _ => itertools::FoldWhile::Continue(r), + } + } + }) + .into_inner() + .into() + } +}