diff --git a/src/build.rs b/src/build.rs index 1e11e63..e26dc4d 100644 --- a/src/build.rs +++ b/src/build.rs @@ -21,6 +21,25 @@ pub struct RunningCargo { message_iter: MessageIter>, } +#[derive(Debug, Copy, Clone, clap::ValueEnum)] +pub enum CargoCommand { + Build, + Test, + Run, + Bench, +} + +impl CargoCommand { + pub fn to_str(&self) -> &str { + match self { + CargoCommand::Build => "build", + CargoCommand::Test => "test", + CargoCommand::Run => "run", + CargoCommand::Bench => "bench", + } + } +} + impl RunningCargo { pub fn messages(&mut self) -> &mut MessageIter> { &mut self.message_iter @@ -121,21 +140,53 @@ fn parse_cargo_args(cargo_args: Vec) -> CargoArgs { "--release" => { log::warn!("Do not pass `--release` manually, it will be added automatically by `cargo-pgo`"); } - // Skip `--message-format`, we need it to be JSON. - "--message-format" => { - log::warn!("Do not pass `--message-format` manually, it will be added automatically by `cargo-pgo`"); - iterator.next(); // skip flag value + _ => { + if get_key_value("--message-format", arg.as_str(), &mut iterator).is_some() { + // Skip `--message-format`, we need it to be JSON. + log::warn!("Do not pass `--message-format` manually, it will be added automatically by `cargo-pgo`"); + } else if let Some(value) = get_key_value("--target", arg.as_str(), &mut iterator) { + // Check if `--target` was passed + args.contains_target = true; + args.filtered.push("--target".to_string()); + if let Some(value) = value { + args.filtered.push(value); + } + } else { + args.filtered.push(arg); + } } - "--target" => { - args.contains_target = true; - args.filtered.push(arg); - } - _ => args.filtered.push(arg), } } args } +/// Parses a `--key=` or `--key ` key/value CLI argument pair. +fn get_key_value>( + key: &str, + arg: &str, + iter: &mut Iter, +) -> Option> { + // A different argument was passed, nothing to be seen here + if !arg.starts_with(key) { + return None; + } + // --key was passed exactly, we should extract the value from the following argument + if arg == key { + let value = iter.next(); + return Some(value); + } + + // --key was passed, let's try to split it into --key=value + if let Some((parsed_key, value)) = arg.split_once('=') { + // if --keyfoo=value was passed, ignore it + if parsed_key == key { + return Some(Some(value.to_string())); + } + } + + None +} + pub fn handle_metadata_message(message: Message) { let stdout = std::io::stdout(); let mut stdout = stdout.lock(); @@ -185,10 +236,10 @@ pub fn get_artifact_kind(artifact: &Artifact) -> &str { #[cfg(test)] mod tests { - use crate::build::parse_cargo_args; + use crate::build::{get_key_value, parse_cargo_args}; #[test] - fn test_parse_cargo_args_filter_release() { + fn parse_cargo_args_filter_release() { let args = parse_cargo_args(vec![ "foo".to_string(), "--release".to_string(), @@ -198,7 +249,7 @@ mod tests { } #[test] - fn test_parse_cargo_args_filter_message_format() { + fn parse_cargo_args_filter_message_format() { let args = parse_cargo_args(vec![ "foo".to_string(), "--message-format".to_string(), @@ -209,7 +260,17 @@ mod tests { } #[test] - fn test_parse_cargo_args_find_target() { + fn parse_cargo_args_filter_message_format_equals() { + let args = parse_cargo_args(vec![ + "foo".to_string(), + "--message-format=json".to_string(), + "bar".to_string(), + ]); + assert_eq!(args.filtered, vec!["foo".to_string(), "bar".to_string()]); + } + + #[test] + fn parse_cargo_args_find_target() { let args = parse_cargo_args(vec![ "--target".to_string(), "x64".to_string(), @@ -221,23 +282,54 @@ mod tests { ); assert!(args.contains_target); } -} -#[derive(Debug, Copy, Clone, clap::ValueEnum)] -pub enum CargoCommand { - Build, - Test, - Run, - Bench, -} + #[test] + fn parse_cargo_args_find_target_equals() { + let args = parse_cargo_args(vec!["--target=x64".to_string(), "bar".to_string()]); + assert_eq!( + args.filtered, + vec!["--target".to_string(), "x64".to_string(), "bar".to_string()] + ); + assert!(args.contains_target); + } -impl CargoCommand { - pub fn to_str(&self) -> &str { - match self { - CargoCommand::Build => "build", - CargoCommand::Test => "test", - CargoCommand::Run => "run", - CargoCommand::Bench => "bench", - } + #[test] + fn get_key_value_wrong_key() { + assert_eq!( + get_key_value("--foo", "--bar", &mut std::iter::empty()), + None + ); + } + + #[test] + fn get_key_value_exact_key_missing_value() { + assert_eq!( + get_key_value("--foo", "--foo", &mut std::iter::empty()), + Some(None) + ); + } + + #[test] + fn get_key_value_exact_key_value() { + assert_eq!( + get_key_value("--foo", "--foo", &mut vec!["bar".to_string()].into_iter()), + Some(Some("bar".to_string())) + ); + } + + #[test] + fn get_key_value_equals_wrong_prefix() { + assert_eq!( + get_key_value("--foo", "--foox=bar", &mut std::iter::empty()), + None + ); + } + + #[test] + fn get_key_value_equals() { + assert_eq!( + get_key_value("--foo", "--foo=bar", &mut std::iter::empty()), + Some(Some("bar".to_string())) + ); } }