diff --git a/crates/lib/src/install.rs b/crates/lib/src/install.rs index 2afc12616..bada00301 100644 --- a/crates/lib/src/install.rs +++ b/crates/lib/src/install.rs @@ -1668,17 +1668,15 @@ fn find_root_args_to_inherit(cmdline: &Cmdline, root_info: &Filesystem) -> Resul .context("Parsing root= karg")?; // If we have a root= karg, then use that let (mount_spec, kargs) = if let Some(root) = root { - let rootflags = cmdline.find(crate::kernel_cmdline::ROOTFLAGS); - let inherit_kargs = cmdline.iter().filter(|arg| { - arg.key - .starts_with(crate::kernel_cmdline::INITRD_ARG_PREFIX) - }); + let rootflags = cmdline.find_str(crate::kernel_cmdline::ROOTFLAGS); + let inherit_kargs = + cmdline.find_all_starting_with_str(crate::kernel_cmdline::INITRD_ARG_PREFIX); ( root.to_owned(), rootflags .into_iter() .chain(inherit_kargs) - .map(|p| p.to_string()) + .map(|p| p.as_ref().to_owned()) .collect(), ) } else { diff --git a/crates/lib/src/kernel_cmdline.rs b/crates/lib/src/kernel_cmdline.rs index e236a4c54..a961d2f31 100644 --- a/crates/lib/src/kernel_cmdline.rs +++ b/crates/lib/src/kernel_cmdline.rs @@ -8,9 +8,9 @@ use std::borrow::Cow; use anyhow::Result; /// This is used by dracut. -pub(crate) const INITRD_ARG_PREFIX: &[u8] = b"rd."; +pub(crate) const INITRD_ARG_PREFIX: &str = "rd."; /// The kernel argument for configuring the rootfs flags. -pub(crate) const ROOTFLAGS: &[u8] = b"rootflags"; +pub(crate) const ROOTFLAGS: &str = "rootflags"; /// A parsed kernel command line. /// @@ -41,7 +41,7 @@ impl<'a> Cmdline<'a> { /// Properly handles quoted values containing whitespace and splits on /// unquoted whitespace characters. Parameters are parsed as either /// key-only switches or key=value pairs. - pub fn iter(&'a self) -> impl Iterator> { + pub fn iter(&'a self) -> impl Iterator> + 'a { let mut in_quotes = false; self.0 @@ -63,6 +63,29 @@ impl<'a> Cmdline<'a> { self.iter().find(|p| p.key == key) } + /// Locate a kernel argument with the given key name that must be UTF-8. + /// + /// Otherwise the same as [`Self::find`]. + pub fn find_str(&'a self, key: &str) -> Option> { + let key = ParameterKeyStr(key); + self.iter() + .filter_map(|p| p.to_str()) + .find(move |p| p.key == key) + } + + /// Find all kernel arguments starting with the given prefix which must be UTF-8. + /// Non-UTF8 values are ignored. + /// + /// This is a variant of [`Self::find`]. + pub fn find_all_starting_with_str( + &'a self, + prefix: &'a str, + ) -> impl Iterator> + 'a { + self.iter() + .filter_map(|p| p.to_str()) + .filter(move |p| p.key.0.starts_with(prefix)) + } + /// Locate the value of the kernel argument with the given key name. /// /// Returns the first value matching the given key, or `None` if not found. @@ -121,47 +144,52 @@ impl<'a> From<&'a [u8]> for ParameterKey<'a> { } } +/// A single kernel command line parameter key that is known to be UTF-8. +/// +/// Otherwise the same as [`ParameterKey`]. +#[derive(Debug, Eq)] +pub(crate) struct ParameterKeyStr<'a>(&'a str); + +impl<'a> From<&'a str> for ParameterKeyStr<'a> { + fn from(value: &'a str) -> Self { + Self(value) + } +} + /// A single kernel command line parameter. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, Eq)] pub(crate) struct Parameter<'a> { + /// The full original value + pub parameter: &'a [u8], /// The parameter key as raw bytes pub key: ParameterKey<'a>, /// The parameter value as raw bytes, if present pub value: Option<&'a [u8]>, } -impl<'a> Parameter<'a> { - /// Create a new parameter with the provided key and value. - #[cfg(test)] - pub fn new_kv<'k: 'a, 'v: 'a>(key: &'k [u8], value: &'v [u8]) -> Self { - Self { - key: ParameterKey(key), - value: Some(value), - } - } - - /// Create a new parameter with the provided key. - #[cfg(test)] - pub fn new_key(key: &'a [u8]) -> Self { - Self { - key: ParameterKey(key), - value: None, - } - } +/// A single kernel command line parameter. +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct ParameterStr<'a> { + /// The original value + pub parameter: &'a str, + /// The parameter key + pub key: ParameterKeyStr<'a>, + /// The parameter value, if present + pub value: Option<&'a str>, +} - /// Returns the key as a lossy UTF-8 string. - /// - /// Invalid UTF-8 sequences are replaced with the Unicode replacement character. - pub fn key_lossy(&self) -> String { - String::from_utf8_lossy(&self.key).to_string() +impl<'a> Parameter<'a> { + pub fn to_str(&self) -> Option> { + let Ok(parameter) = std::str::from_utf8(self.parameter) else { + return None; + }; + Some(ParameterStr::from(parameter)) } +} - /// Returns the value as a lossy UTF-8 string. - /// - /// Invalid UTF-8 sequences are replaced with the Unicode replacement character. - /// Returns an empty string if no value is present. - pub fn value_lossy(&self) -> String { - String::from_utf8_lossy(self.value.unwrap_or(&[])).to_string() +impl<'a> AsRef for ParameterStr<'a> { + fn as_ref(&self) -> &str { + self.parameter } } @@ -177,6 +205,7 @@ impl<'a, T: AsRef<[u8]> + ?Sized> From<&'a T> for Parameter<'a> { match equals { None => Self { + parameter: input, key: ParameterKey(input), value: None, }, @@ -196,6 +225,7 @@ impl<'a, T: AsRef<[u8]> + ?Sized> From<&'a T> for Parameter<'a> { .unwrap_or(value); Self { + parameter: input, key, value: Some(value), } @@ -228,29 +258,40 @@ impl PartialEq for ParameterKey<'_> { } } -impl std::fmt::Display for Parameter<'_> { - /// Formats the parameter for display. - /// - /// Key-only parameters are displayed as just the key. - /// Key-value parameters are displayed as `key=value`. - /// Values containing whitespace are automatically quoted. - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let key = self.key_lossy(); - - if self.value.is_some() { - let value = self.value_lossy(); - - if value.chars().any(|c| c.is_ascii_whitespace()) { - write!(f, "{key}=\"{value}\"") - } else { - write!(f, "{key}={value}") - } +impl<'a> From<&'a str> for ParameterStr<'a> { + fn from(parameter: &'a str) -> Self { + let (key, value) = if let Some((key, value)) = parameter.split_once('=') { + let value = value + .strip_prefix('"') + .unwrap_or(value) + .strip_suffix('"') + .unwrap_or(value); + (key, Some(value)) } else { - write!(f, "{key}") + (parameter, None) + }; + let key = ParameterKeyStr(key); + ParameterStr { + parameter, + key, + value, } } } +impl<'a> PartialEq for Parameter<'a> { + fn eq(&self, other: &Self) -> bool { + // Note we don't compare parameter because we want hyphen-dash insensitivity for the key + self.key == other.key && self.value == other.value + } +} + +impl<'a> PartialEq for ParameterKeyStr<'a> { + fn eq(&self, other: &Self) -> bool { + ParameterKey(self.0.as_bytes()) == ParameterKey(other.0.as_bytes()) + } +} + #[cfg(test)] mod tests { use super::*; @@ -293,9 +334,6 @@ mod tests { p.push(non_utf8_byte[0]); let p = Parameter::from(&p); assert_eq!(p.value, Some(non_utf8_byte.as_slice())); - - // lossy replacement sanity check - assert_eq!(p.value_lossy(), char::REPLACEMENT_CHARACTER.to_string()); } #[test] @@ -334,14 +372,9 @@ mod tests { let kargs = Cmdline::from(b"foo=bar,bar2 baz=fuz wiz".as_slice()); let mut iter = kargs.iter(); - assert_eq!(iter.next(), Some(Parameter::new_kv(b"foo", b"bar,bar2"))); - - assert_eq!( - iter.next(), - Some(Parameter::new_kv(b"baz", b"fuz".as_slice())) - ); - - assert_eq!(iter.next(), Some(Parameter::new_key(b"wiz"))); + assert_eq!(iter.next(), Some(Parameter::from(b"foo=bar,bar2"))); + assert_eq!(iter.next(), Some(Parameter::from(b"baz=fuz"))); + assert_eq!(iter.next(), Some(Parameter::from(b"wiz"))); assert_eq!(iter.next(), None); // Test the find API @@ -483,4 +516,70 @@ mod tests { let kargs = Cmdline::from(&invalid_utf8); assert!(kargs.require_value_of_utf8("invalid").is_err()); } + + #[test] + fn test_find_str() { + let kargs = Cmdline::from(b"foo=bar baz=qux switch rd.break".as_slice()); + let p = kargs.find_str("foo").unwrap(); + assert_eq!(p, ParameterStr::from("foo=bar")); + assert_eq!(p.as_ref(), "foo=bar"); + let p = kargs.find_str("rd.break").unwrap(); + assert_eq!(p, ParameterStr::from("rd.break")); + assert!(kargs.find_str("missing").is_none()); + } + + #[test] + fn test_find_all_str() { + let kargs = + Cmdline::from(b"foo=bar rd.foo=a rd.bar=b rd.baz rd.qux=c notrd.val=d".as_slice()); + let mut rd_args: Vec<_> = kargs.find_all_starting_with_str("rd.").collect(); + rd_args.sort_by(|a, b| a.key.0.cmp(b.key.0)); + assert_eq!(rd_args.len(), 4); + assert_eq!(rd_args[0], ParameterStr::from("rd.bar=b")); + assert_eq!(rd_args[1], ParameterStr::from("rd.baz")); + assert_eq!(rd_args[2], ParameterStr::from("rd.foo=a")); + assert_eq!(rd_args[3], ParameterStr::from("rd.qux=c")); + } + + #[test] + fn test_param_to_str() { + let p = Parameter::from("foo=bar"); + let p_str = p.to_str().unwrap(); + assert_eq!(p_str, ParameterStr::from("foo=bar")); + let non_utf8_byte = b"\xff"; + let mut p_u8 = b"foo=".to_vec(); + p_u8.push(non_utf8_byte[0]); + let p = Parameter::from(&p_u8); + assert!(p.to_str().is_none()); + } + + #[test] + fn test_param_key_str_eq() { + let k1 = ParameterKeyStr("a-b"); + let k2 = ParameterKeyStr("a_b"); + assert_eq!(k1, k2); + let k1 = ParameterKeyStr("a-b"); + let k2 = ParameterKeyStr("a-c"); + assert_ne!(k1, k2); + } + + #[test] + fn test_kargs_non_utf8() { + let non_utf8_val = b"an_invalid_key=\xff"; + let mut kargs_bytes = b"foo=bar ".to_vec(); + kargs_bytes.extend_from_slice(non_utf8_val); + kargs_bytes.extend_from_slice(b" baz=qux"); + let kargs = Cmdline::from(kargs_bytes.as_slice()); + + // We should be able to find the valid kargs + assert_eq!(kargs.find_str("foo").unwrap().value, Some("bar")); + assert_eq!(kargs.find_str("baz").unwrap().value, Some("qux")); + + // But we should not find the invalid one via find_str + assert!(kargs.find("an_invalid_key").unwrap().to_str().is_none()); + + // And even using the raw find, trying to convert it to_str will fail. + let raw_param = kargs.find("an_invalid_key").unwrap(); + assert_eq!(raw_param.value.unwrap(), b"\xff"); + } }