Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions crates/lib/src/install.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1668,17 +1668,15 @@ fn find_root_args_to_inherit(cmdline: &Cmdline, root_info: &Filesystem) -> Resul
.context("Parsing root= karg")?;
// If we have a root= karg, then use that
let (mount_spec, kargs) = if let Some(root) = root {
let rootflags = cmdline.find(crate::kernel_cmdline::ROOTFLAGS);
let inherit_kargs = cmdline.iter().filter(|arg| {
arg.key
.starts_with(crate::kernel_cmdline::INITRD_ARG_PREFIX)
});
let rootflags = cmdline.find_str(crate::kernel_cmdline::ROOTFLAGS);
let inherit_kargs =
cmdline.find_all_starting_with_str(crate::kernel_cmdline::INITRD_ARG_PREFIX);
(
root.to_owned(),
rootflags
.into_iter()
.chain(inherit_kargs)
.map(|p| p.to_string())
.map(|p| p.as_ref().to_owned())
.collect(),
)
} else {
Expand Down
223 changes: 161 additions & 62 deletions crates/lib/src/kernel_cmdline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use std::borrow::Cow;
use anyhow::Result;

/// This is used by dracut.
pub(crate) const INITRD_ARG_PREFIX: &[u8] = b"rd.";
pub(crate) const INITRD_ARG_PREFIX: &str = "rd.";
/// The kernel argument for configuring the rootfs flags.
pub(crate) const ROOTFLAGS: &[u8] = b"rootflags";
pub(crate) const ROOTFLAGS: &str = "rootflags";

/// A parsed kernel command line.
///
Expand Down Expand Up @@ -41,7 +41,7 @@ impl<'a> Cmdline<'a> {
/// Properly handles quoted values containing whitespace and splits on
/// unquoted whitespace characters. Parameters are parsed as either
/// key-only switches or key=value pairs.
pub fn iter(&'a self) -> impl Iterator<Item = Parameter<'a>> {
pub fn iter(&'a self) -> impl Iterator<Item = Parameter<'a>> + 'a {
let mut in_quotes = false;

self.0
Expand All @@ -63,6 +63,29 @@ impl<'a> Cmdline<'a> {
self.iter().find(|p| p.key == key)
}

/// Locate a kernel argument with the given key name that must be UTF-8.
///
/// Otherwise the same as [`Self::find`].
pub fn find_str(&'a self, key: &str) -> Option<ParameterStr<'a>> {
let key = ParameterKeyStr(key);
self.iter()
.filter_map(|p| p.to_str())
.find(move |p| p.key == key)
}

/// Find all kernel arguments starting with the given prefix which must be UTF-8.
/// Non-UTF8 values are ignored.
///
/// This is a variant of [`Self::find`].
pub fn find_all_starting_with_str(
&'a self,
prefix: &'a str,
) -> impl Iterator<Item = ParameterStr<'a>> + 'a {
self.iter()
.filter_map(|p| p.to_str())
.filter(move |p| p.key.0.starts_with(prefix))
}

/// Locate the value of the kernel argument with the given key name.
///
/// Returns the first value matching the given key, or `None` if not found.
Expand Down Expand Up @@ -121,47 +144,52 @@ impl<'a> From<&'a [u8]> for ParameterKey<'a> {
}
}

/// A single kernel command line parameter key that is known to be UTF-8.
///
/// Otherwise the same as [`ParameterKey`].
#[derive(Debug, Eq)]
pub(crate) struct ParameterKeyStr<'a>(&'a str);

impl<'a> From<&'a str> for ParameterKeyStr<'a> {
fn from(value: &'a str) -> Self {
Self(value)
}
}

/// A single kernel command line parameter.
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, Eq)]
pub(crate) struct Parameter<'a> {
/// The full original value
pub parameter: &'a [u8],
/// The parameter key as raw bytes
pub key: ParameterKey<'a>,
/// The parameter value as raw bytes, if present
pub value: Option<&'a [u8]>,
}

impl<'a> Parameter<'a> {
/// Create a new parameter with the provided key and value.
#[cfg(test)]
pub fn new_kv<'k: 'a, 'v: 'a>(key: &'k [u8], value: &'v [u8]) -> Self {
Self {
key: ParameterKey(key),
value: Some(value),
}
}

/// Create a new parameter with the provided key.
#[cfg(test)]
pub fn new_key(key: &'a [u8]) -> Self {
Self {
key: ParameterKey(key),
value: None,
}
}
/// A single kernel command line parameter.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct ParameterStr<'a> {
/// The original value
pub parameter: &'a str,
/// The parameter key
pub key: ParameterKeyStr<'a>,
/// The parameter value, if present
pub value: Option<&'a str>,
}

/// Returns the key as a lossy UTF-8 string.
///
/// Invalid UTF-8 sequences are replaced with the Unicode replacement character.
pub fn key_lossy(&self) -> String {
String::from_utf8_lossy(&self.key).to_string()
impl<'a> Parameter<'a> {
pub fn to_str(&self) -> Option<ParameterStr<'a>> {
let Ok(parameter) = std::str::from_utf8(self.parameter) else {
return None;
};
Some(ParameterStr::from(parameter))
}
}

/// Returns the value as a lossy UTF-8 string.
///
/// Invalid UTF-8 sequences are replaced with the Unicode replacement character.
/// Returns an empty string if no value is present.
pub fn value_lossy(&self) -> String {
String::from_utf8_lossy(self.value.unwrap_or(&[])).to_string()
impl<'a> AsRef<str> for ParameterStr<'a> {
fn as_ref(&self) -> &str {
self.parameter
}
}

Expand All @@ -177,6 +205,7 @@ impl<'a, T: AsRef<[u8]> + ?Sized> From<&'a T> for Parameter<'a> {

match equals {
None => Self {
parameter: input,
key: ParameterKey(input),
value: None,
},
Expand All @@ -196,6 +225,7 @@ impl<'a, T: AsRef<[u8]> + ?Sized> From<&'a T> for Parameter<'a> {
.unwrap_or(value);

Self {
parameter: input,
key,
value: Some(value),
}
Expand Down Expand Up @@ -228,29 +258,40 @@ impl PartialEq for ParameterKey<'_> {
}
}

impl std::fmt::Display for Parameter<'_> {
/// Formats the parameter for display.
///
/// Key-only parameters are displayed as just the key.
/// Key-value parameters are displayed as `key=value`.
/// Values containing whitespace are automatically quoted.
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let key = self.key_lossy();

if self.value.is_some() {
let value = self.value_lossy();

if value.chars().any(|c| c.is_ascii_whitespace()) {
write!(f, "{key}=\"{value}\"")
} else {
write!(f, "{key}={value}")
}
impl<'a> From<&'a str> for ParameterStr<'a> {
fn from(parameter: &'a str) -> Self {
let (key, value) = if let Some((key, value)) = parameter.split_once('=') {
let value = value
.strip_prefix('"')
.unwrap_or(value)
.strip_suffix('"')
.unwrap_or(value);
(key, Some(value))
} else {
write!(f, "{key}")
(parameter, None)
};
let key = ParameterKeyStr(key);
ParameterStr {
parameter,
key,
value,
}
}
}

impl<'a> PartialEq for Parameter<'a> {
fn eq(&self, other: &Self) -> bool {
// Note we don't compare parameter because we want hyphen-dash insensitivity for the key
self.key == other.key && self.value == other.value
}
}

impl<'a> PartialEq for ParameterKeyStr<'a> {
fn eq(&self, other: &Self) -> bool {
ParameterKey(self.0.as_bytes()) == ParameterKey(other.0.as_bytes())
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -293,9 +334,6 @@ mod tests {
p.push(non_utf8_byte[0]);
let p = Parameter::from(&p);
assert_eq!(p.value, Some(non_utf8_byte.as_slice()));

// lossy replacement sanity check
assert_eq!(p.value_lossy(), char::REPLACEMENT_CHARACTER.to_string());
}

#[test]
Expand Down Expand Up @@ -334,14 +372,9 @@ mod tests {
let kargs = Cmdline::from(b"foo=bar,bar2 baz=fuz wiz".as_slice());
let mut iter = kargs.iter();

assert_eq!(iter.next(), Some(Parameter::new_kv(b"foo", b"bar,bar2")));

assert_eq!(
iter.next(),
Some(Parameter::new_kv(b"baz", b"fuz".as_slice()))
);

assert_eq!(iter.next(), Some(Parameter::new_key(b"wiz")));
assert_eq!(iter.next(), Some(Parameter::from(b"foo=bar,bar2")));
assert_eq!(iter.next(), Some(Parameter::from(b"baz=fuz")));
assert_eq!(iter.next(), Some(Parameter::from(b"wiz")));
assert_eq!(iter.next(), None);

// Test the find API
Expand Down Expand Up @@ -483,4 +516,70 @@ mod tests {
let kargs = Cmdline::from(&invalid_utf8);
assert!(kargs.require_value_of_utf8("invalid").is_err());
}

#[test]
fn test_find_str() {
let kargs = Cmdline::from(b"foo=bar baz=qux switch rd.break".as_slice());
let p = kargs.find_str("foo").unwrap();
assert_eq!(p, ParameterStr::from("foo=bar"));
assert_eq!(p.as_ref(), "foo=bar");
let p = kargs.find_str("rd.break").unwrap();
assert_eq!(p, ParameterStr::from("rd.break"));
assert!(kargs.find_str("missing").is_none());
}

#[test]
fn test_find_all_str() {
let kargs =
Cmdline::from(b"foo=bar rd.foo=a rd.bar=b rd.baz rd.qux=c notrd.val=d".as_slice());
let mut rd_args: Vec<_> = kargs.find_all_starting_with_str("rd.").collect();
rd_args.sort_by(|a, b| a.key.0.cmp(b.key.0));
assert_eq!(rd_args.len(), 4);
assert_eq!(rd_args[0], ParameterStr::from("rd.bar=b"));
assert_eq!(rd_args[1], ParameterStr::from("rd.baz"));
assert_eq!(rd_args[2], ParameterStr::from("rd.foo=a"));
assert_eq!(rd_args[3], ParameterStr::from("rd.qux=c"));
}

#[test]
fn test_param_to_str() {
let p = Parameter::from("foo=bar");
let p_str = p.to_str().unwrap();
assert_eq!(p_str, ParameterStr::from("foo=bar"));
let non_utf8_byte = b"\xff";
let mut p_u8 = b"foo=".to_vec();
p_u8.push(non_utf8_byte[0]);
let p = Parameter::from(&p_u8);
assert!(p.to_str().is_none());
}

#[test]
fn test_param_key_str_eq() {
let k1 = ParameterKeyStr("a-b");
let k2 = ParameterKeyStr("a_b");
assert_eq!(k1, k2);
let k1 = ParameterKeyStr("a-b");
let k2 = ParameterKeyStr("a-c");
assert_ne!(k1, k2);
}

#[test]
fn test_kargs_non_utf8() {
let non_utf8_val = b"an_invalid_key=\xff";
let mut kargs_bytes = b"foo=bar ".to_vec();
kargs_bytes.extend_from_slice(non_utf8_val);
kargs_bytes.extend_from_slice(b" baz=qux");
let kargs = Cmdline::from(kargs_bytes.as_slice());

// We should be able to find the valid kargs
assert_eq!(kargs.find_str("foo").unwrap().value, Some("bar"));
assert_eq!(kargs.find_str("baz").unwrap().value, Some("qux"));

// But we should not find the invalid one via find_str
assert!(kargs.find("an_invalid_key").unwrap().to_str().is_none());

// And even using the raw find, trying to convert it to_str will fail.
let raw_param = kargs.find("an_invalid_key").unwrap();
assert_eq!(raw_param.value.unwrap(), b"\xff");
}
}