From 2417c9287b6ab12f3cde9ebd9787e5acb6778a69 Mon Sep 17 00:00:00 2001 From: t7phy Date: Thu, 18 Jul 2024 16:45:08 +0200 Subject: [PATCH] Migrate `Channel` methods to new definition --- pineappl/src/boc.rs | 129 ++++++++++++++++++++++---------------------- 1 file changed, 63 insertions(+), 66 deletions(-) diff --git a/pineappl/src/boc.rs b/pineappl/src/boc.rs index c0a1794d..9965459d 100644 --- a/pineappl/src/boc.rs +++ b/pineappl/src/boc.rs @@ -311,22 +311,22 @@ impl Channel { // sort `entry` because the ordering doesn't matter and because it makes it easier to // compare `Channel` objects with each other - entry.sort_by(|x, y| (x.0, x.1).cmp(&(y.0, y.1))); + entry.sort_by(|x, y| x.0.cmp(&y.0)); Self { entry: entry .into_iter() .coalesce(|lhs, rhs| { // sum the factors of repeated elements - if (lhs.0, lhs.1) == (rhs.0, rhs.1) { - Ok((lhs.0, lhs.1, lhs.2 + rhs.2)) + if lhs.0 == rhs.0 { + Ok((lhs.0, lhs.1 + rhs.1)) } else { Err((lhs, rhs)) } }) // filter zeros // TODO: find a better than to hardcode the epsilon limit - .filter(|&(_, _, f)| !approx_eq!(f64, f.abs(), 0.0, epsilon = 1e-14)) + .filter(|&(_, f)| !approx_eq!(f64, f.abs(), 0.0, epsilon = 1e-14)) .collect(), } } @@ -347,17 +347,22 @@ impl Channel { /// assert_eq!(entry, channel![2, 11, 1.0; -2, 11, -1.0; 1, 11, -1.0; -1, 11, 1.0]); /// ``` pub fn translate(entry: &Self, translator: &dyn Fn(i32) -> Vec<(i32, f64)>) -> Self { - let mut tuples = Vec::new(); + let mut result = Vec::new(); - for &(a, b, factor) in &entry.entry { - for (aid, af) in translator(a) { - for (bid, bf) in translator(b) { - tuples.push((aid, bid, factor * af * bf)); - } + for &(pids, factor) in &entry.entry { + for tuples in pids + .iter() + .map(|&pid| translator(pid)) + .multi_cartesian_product() + { + result.push(( + tuples.iter().map(|&(pid, _)| pid).collect(), + tuples.iter().map(|(_, f)| f).product::(), + )); } } - Self::new(tuples) + Self::new(result) } /// Returns a tuple representation of this entry. @@ -377,11 +382,11 @@ impl Channel { &self.entry } - /// Creates a new object with the initial states transposed. - #[must_use] - pub fn transpose(&self) -> Self { - Self::new(self.entry.iter().map(|(a, b, c)| (*b, *a, *c)).collect()) - } + // /// Creates a new object with the initial states transposed. + // #[must_use] + // pub fn transpose(&self) -> Self { + // Self::new(self.entry.iter().map(|(a, b, c)| (*b, *a, *c)).collect()) + // } /// If `other` is the same channel when only comparing PIDs and neglecting the factors, return /// the number `f1 / f2`, where `f1` is the factor from `self` and `f2` is the factor from @@ -392,10 +397,10 @@ impl Channel { /// ```rust /// use pineappl::boc::Channel; /// - /// let entry1 = Channel::new(vec![(2, 2, 2.0), (4, 4, 2.0)]); - /// let entry2 = Channel::new(vec![(4, 4, 1.0), (2, 2, 1.0)]); - /// let entry3 = Channel::new(vec![(3, 4, 1.0), (2, 2, 1.0)]); - /// let entry4 = Channel::new(vec![(4, 3, 1.0), (2, 3, 2.0)]); + /// let entry1 = Channel::new(vec![(vec![2, 2], 2.0), (vec![4, 4], 2.0)]); + /// let entry2 = Channel::new(vec![(vec![4, 4], 1.0), (vec![2, 2], 1.0)]); + /// let entry3 = Channel::new(vec![(vec![3, 4], 1.0), (vec![2, 2], 1.0)]); + /// let entry4 = Channel::new(vec![(vec![4, 3], 1.0), (vec![2, 3], 2.0)]); /// /// assert_eq!(entry1.common_factor(&entry2), Some(2.0)); /// assert_eq!(entry1.common_factor(&entry3), None); @@ -411,7 +416,7 @@ impl Channel { .entry .iter() .zip(&other.entry) - .map(|(a, b)| ((a.0 == b.0) && (a.1 == b.1)).then_some(a.2 / b.2)) + .map(|(a, b)| (a == b).then_some(a.1 / b.1)) .collect(); result.and_then(|factors| { @@ -436,51 +441,43 @@ impl FromStr for Channel { type Err = ParseChannelError; fn from_str(s: &str) -> Result { - Ok(Self::new( - s.split('+') - .map(|sub| { - sub.split_once('*').map_or_else( - || Err(ParseChannelError(format!("missing '*' in '{sub}'"))), - |(factor, pids)| { - let tuple = pids.split_once(',').map_or_else( - || Err(ParseChannelError(format!("missing ',' in '{pids}'"))), - |(a, b)| { - Ok(( - a.trim() - .strip_prefix('(') - .ok_or_else(|| { - ParseChannelError(format!( - "missing '(' in '{pids}'" - )) - })? - .trim() - .parse::() - .map_err(|err| ParseChannelError(err.to_string()))?, - b.trim() - .strip_suffix(')') - .ok_or_else(|| { - ParseChannelError(format!( - "missing ')' in '{pids}'" - )) - })? - .trim() - .parse::() - .map_err(|err| ParseChannelError(err.to_string()))?, - )) - }, - )?; - - Ok(( - tuple.0, - tuple.1, - str::parse::(factor.trim()) - .map_err(|err| ParseChannelError(err.to_string()))?, - )) - }, - ) - }) - .collect::>()?, - )) + let result: Vec<_> = s + .split('+') + .map(|sub| { + sub.split_once('*').map_or_else( + || Err(ParseChannelError(format!("missing '*' in '{sub}'"))), + |(factor, pids)| { + let vector: Vec<_> = pids + .strip_prefix('(') + .ok_or_else(|| ParseChannelError(format!("missing '(' in '{pids}'")))? + .strip_suffix(')') + .ok_or_else(|| ParseChannelError(format!("missing ')' in '{pids}'")))? + .split(',') + .map(|pid| { + Ok(pid + .trim() + .parse::() + .map_err(|err| ParseChannelError(err.to_string()))?) + }) + .collect::>()?; + + Ok(( + vector, + str::parse::(factor.trim()) + .map_err(|err| ParseChannelError(err.to_string()))?, + )) + }, + ) + }) + .collect::>()?; + + if !result.iter().map(|(pids, _)| pids.len()).all_equal() { + return Err(ParseChannelError(format!( + "PID tuples have different lengths" + ))); + } + + Ok(Self::new(result)) } }