diff --git a/src/dot.rs b/src/dot.rs index cefaf440..a086a750 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -22,7 +22,7 @@ The [`EGraph::dot`](EGraph::dot()) method creates `Dot`s. # Example ```no_run -use egg::{*, rewrite as rw}; +use egg::legacy::{*, rewrite as rw}; let rules = &[ rw!("mul-commutes"; "(* ?x ?y)" => "(* ?y ?x)"), @@ -192,17 +192,19 @@ where writeln!(f, " {}", line)?; } + let classes = self.egraph.generate_class_nodes(); + // define all the nodes, clustered by eclass - for class in self.egraph.classes() { - writeln!(f, " subgraph cluster_{} {{", class.id)?; + for (&id, class) in &classes { + writeln!(f, " subgraph cluster_{} {{", id)?; writeln!(f, " style=dotted")?; for (i, node) in class.iter().enumerate() { - writeln!(f, " {}.{}[label = \"{}\"]", class.id, i, node)?; + writeln!(f, " {}.{}[label = \"{}\"]", id, i, node)?; } writeln!(f, " }}")?; } - for class in self.egraph.classes() { + for (&id, class) in &classes { for (i_in_class, node) in class.iter().enumerate() { let mut arg_i = 0; node.try_for_each(|child| { @@ -210,19 +212,19 @@ where let (anchor, label) = self.edge(arg_i, node.len()); let child_leader = self.egraph.find(child); - if child_leader == class.id { + if child_leader == id { writeln!( f, // {}.0 to pick an arbitrary node in the cluster " {}.{}{} -> {}.{}:n [lhead = cluster_{}, {}]", - class.id, i_in_class, anchor, class.id, i_in_class, class.id, label + id, i_in_class, anchor, id, i_in_class, id, label )?; } else { writeln!( f, // {}.0 to pick an arbitrary node in the cluster " {}.{}{} -> {}.0 [lhead = cluster_{}, {}]", - class.id, i_in_class, anchor, child, child_leader, label + id, i_in_class, anchor, child, child_leader, label )?; } arg_i += 1; diff --git a/src/eclass.rs b/src/eclass.rs index 5f74b2c2..feb8aee0 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -7,46 +7,29 @@ use crate::*; #[non_exhaustive] #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -pub struct EClass { +pub struct EClass { /// This eclass's id. pub id: Id, - /// The equivalent enodes in this equivalence class. - pub nodes: Vec, /// The analysis data associated with this eclass. /// /// Modifying this field will _not_ cause changes to propagate through the e-graph. /// Prefer [`EGraph::set_analysis_data`] instead. pub data: D, - /// The parent enodes and their original Ids. - pub(crate) parents: Vec<(L, Id)>, + /// The original Ids of parent enodes. + pub(crate) parents: Vec, } -impl EClass { - /// Returns `true` if the `eclass` is empty. - pub fn is_empty(&self) -> bool { - self.nodes.is_empty() - } - - /// Returns the number of enodes in this eclass. - pub fn len(&self) -> usize { - self.nodes.len() - } - - /// Iterates over the enodes in this eclass. - pub fn iter(&self) -> impl ExactSizeIterator { - self.nodes.iter() - } - - /// Iterates over the parent enodes of this eclass. - pub fn parents(&self) -> impl ExactSizeIterator { - self.parents.iter().map(|(node, id)| (node, *id)) +impl EClass { + /// Iterates over the non-canonical ids of parent enodes of this eclass. + pub fn parents(&self) -> impl ExactSizeIterator + '_ { + self.parents.iter().copied() } } -impl EClass { +impl EMClass { /// Iterates over the childless enodes in this eclass. pub fn leaves(&self) -> impl Iterator { - self.nodes.iter().filter(|&n| n.is_leaf()) + self.iter().filter(|&n| n.is_leaf()) } /// Asserts that the childless enodes in this eclass are unique. @@ -64,4 +47,24 @@ impl EClass { ); } } + + /// The equivalent enodes in this equivalence class. + pub fn nodes(&self) -> &[L] { + &self.data.1 + } + + /// Returns `true` if the `eclass` is empty. + pub fn is_empty(&self) -> bool { + self.nodes().is_empty() + } + + /// Returns the number of enodes in this eclass. + pub fn len(&self) -> usize { + self.nodes().len() + } + + /// Iterates over the enodes in this eclass. + pub fn iter(&self) -> impl ExactSizeIterator { + self.nodes().iter() + } } diff --git a/src/egraph.rs b/src/egraph.rs index 6af452b2..fb0600c8 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1,7 +1,8 @@ +use crate::generic_analysis::{Analysis, Just}; use crate::*; use std::{ borrow::BorrowMut, - fmt::{self, Debug, Display}, + fmt::{self, Debug}, }; #[cfg(feature = "serde-1")] @@ -56,7 +57,9 @@ pub struct EGraph> { pub analysis: N, /// The `Explain` used to explain equivalences in this `EGraph`. pub(crate) explain: Option>, - unionfind: UnionFind, + pub(crate) unionfind: UnionFind, + /// Stores the original node represented by each non-canonical id + pub(crate) nodes: Vec, /// Stores each enode's `Id`, not the `Id` of the eclass. /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new /// unions can cause them to become out of date. @@ -64,8 +67,7 @@ pub struct EGraph> { memo: HashMap, /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, /// not the canonical id of the eclass. - pending: Vec<(L, Id)>, - analysis_pending: UniqueQueue<(L, Id)>, + pending: Vec, #[cfg_attr( feature = "serde-1", serde(bound( @@ -73,10 +75,7 @@ pub struct EGraph> { deserialize = "N::Data: for<'a> Deserialize<'a>", )) )] - pub(crate) classes: HashMap>, - #[cfg_attr(feature = "serde-1", serde(skip))] - #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))] - pub(crate) classes_by_op: HashMap>, + pub(crate) classes: HashMap>, /// Whether or not reading operation are allowed on this e-graph. /// Mutating operations will set this to `false`, and /// [`EGraph::rebuild`] will set it to true. @@ -86,11 +85,6 @@ pub struct EGraph> { pub clean: bool, } -#[cfg(feature = "serde-1")] -fn default_classes_by_op() -> HashMap> { - HashMap::default() -} - impl + Default> Default for EGraph { fn default() -> Self { Self::new(N::default()) @@ -114,25 +108,33 @@ impl> EGraph { analysis, classes: Default::default(), unionfind: Default::default(), + nodes: Default::default(), clean: false, explain: None, pending: Default::default(), memo: Default::default(), - analysis_pending: Default::default(), - classes_by_op: Default::default(), } } /// Returns an iterator over the eclasses in the egraph. - pub fn classes(&self) -> impl ExactSizeIterator> { + pub fn classes(&self) -> impl ExactSizeIterator> { self.classes.values() } /// Returns an mutating iterator over the eclasses in the egraph. - pub fn classes_mut(&mut self) -> impl ExactSizeIterator> { + pub fn classes_mut(&mut self) -> impl ExactSizeIterator> { self.classes.values_mut() } + pub(crate) fn class_and_rest(&mut self, id: Id) -> (&mut EClass, &mut N, &UnionFind) { + let id = self.find_mut(id); + let class = self + .classes + .get_mut(&id) + .unwrap_or_else(|| panic!("Invalid id {}", id)); + (class, &mut self.analysis, &self.unionfind) + } + /// Returns `true` if the egraph is empty /// # Example /// ``` @@ -166,9 +168,17 @@ impl> EGraph { self.memo.len() } - /// Iterates over the classes, returning the total number of nodes. + /// Returns an iterator over the nodes in the egraph and there uncanonical `Id`s. + pub(crate) fn nodes(&self) -> impl ExactSizeIterator { + self.nodes + .iter() + .enumerate() + .map(|(id, v)| (Id::from(id), v)) + } + + /// Returns the number of nodes in the egraph pub fn total_number_of_nodes(&self) -> usize { - self.classes().map(|c| c.len()).sum() + self.nodes.len() } /// Returns the number of eclasses in the egraph. @@ -214,12 +224,14 @@ impl> EGraph { /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { - if let Some(explain) = &self.explain { - let egraph = Self::new(analysis); - explain.populate_enodes(egraph) - } else { + if self.explain.is_none() { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); } + let mut egraph = Self::new(analysis); + for node in &self.nodes { + egraph.add(node.clone()); + } + egraph } /// Performs the union between two egraphs. @@ -279,12 +291,23 @@ impl> EGraph { /// are not captured in the intersection. /// The runtime of this algorithm is O(|E1| * |E2|), where |E1| and |E2| are the number of enodes in each egraph. pub fn egraph_intersect(&self, other: &EGraph, analysis: N) -> EGraph { + let class_to_nodes = self.generate_class_nodes(); + let other_class_to_nodes = other.generate_class_nodes(); + let mut product_map: HashMap<(Id, Id), Id> = Default::default(); let mut enodes = vec![]; - for class1 in self.classes() { - for class2 in other.classes() { - self.intersect_classes(other, &mut enodes, class1.id, class2.id, &mut product_map); + for (&id1, nodes1) in &class_to_nodes { + for (&id2, nodes2) in &other_class_to_nodes { + self.intersect_classes( + other, + &mut enodes, + id1, + nodes1, + id2, + nodes2, + &mut product_map, + ); } } @@ -306,12 +329,14 @@ impl> EGraph { other: &EGraph, res: &mut Vec<(L, Id)>, class1: Id, + nodes1: &[L], class2: Id, + nodes2: &[L], product_map: &mut HashMap<(Id, Id), Id>, ) { let res_id = Self::get_product_id(class1, class2, product_map); - for node1 in &self.classes[&class1].nodes { - for node2 in &other.classes[&class2].nodes { + for node1 in nodes1 { + for node2 in nodes2 { if node1.matches(node2) { let children1 = node1.children(); let children2 = node2.children(); @@ -339,20 +364,33 @@ impl> EGraph { /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical), /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical)) pub fn id_to_expr(&self, id: Id) -> RecExpr { - if let Some(explain) = &self.explain { - explain.node_to_recexpr(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); + let mut res = Default::default(); + let mut cache = Default::default(); + self.id_to_expr_internal(&mut res, id, &mut cache); + res + } + + fn id_to_expr_internal( + &self, + res: &mut RecExpr, + node_id: Id, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let new_node = self + .id_to_node(node_id) + .clone() + .map_children(|child| self.id_to_expr_internal(res, child, cache)); + let res_id = res.add(new_node); + cache.insert(node_id, res_id); + res_id } /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep pub fn id_to_node(&self, id: Id) -> &L { - if let Some(explain) = &self.explain { - explain.node(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); - } + &self.nodes[usize::from(id)] } /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. @@ -360,11 +398,36 @@ impl> EGraph { /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr). pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap) -> (Pattern, Subst) { - if let Some(explain) = &self.explain { - explain.node_to_pattern(id, substitutions) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id"); + let mut res = Default::default(); + let mut subst = Default::default(); + let mut cache = Default::default(); + self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache); + (Pattern::new(res), subst) + } + + fn id_to_pattern_internal( + &self, + res: &mut PatternAst, + node_id: Id, + var_substitutions: &HashMap, + subst: &mut Subst, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let res_id = if let Some(existing) = var_substitutions.get(&node_id) { + let var = format!("?{}", node_id).parse().unwrap(); + subst.insert(var, *existing); + res.add(ENodeOrVar::Var(var)) + } else { + let new_node = self.id_to_node(node_id).clone().map_children(|child| { + self.id_to_pattern_internal(res, child, var_substitutions, subst, cache) + }); + res.add(ENodeOrVar::ENode(new_node)) + }; + cache.insert(node_id, res_id); + res_id } /// Get all the unions ever found in the egraph in terms of enode ids. @@ -390,8 +453,10 @@ impl> EGraph { /// Get the number of congruences between nodes in the egraph. /// Only available when explanations are enabled. pub fn get_num_congr(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_congr::(&self.classes, &self.unionfind) + if let Some(explain) = &mut self.explain { + explain + .with_nodes(&self.nodes) + .get_num_congr::(&self.classes, &self.unionfind) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -399,8 +464,8 @@ impl> EGraph { /// Get the number of nodes in the egraph used for explanations. pub fn get_explanation_num_nodes(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_nodes() + if let Some(explain) = &mut self.explain { + explain.with_nodes(&self.nodes).get_num_nodes() } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -438,7 +503,12 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain.with_nodes(&self.nodes).explain_equivalence::( + left, + right, + &mut self.unionfind, + &self.classes, + ) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -461,7 +531,7 @@ impl> EGraph { /// but more efficient fn explain_existance_id(&mut self, id: Id) -> Explanation { if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_nodes(&self.nodes).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -475,7 +545,7 @@ impl> EGraph { ) -> Explanation { let id = self.add_instantiation_noncanonical(pattern, subst); if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_nodes(&self.nodes).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -498,7 +568,12 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain.with_nodes(&self.nodes).explain_equivalence::( + left, + right, + &mut self.unionfind, + &self.classes, + ) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations."); } @@ -544,7 +619,7 @@ impl> EGraph { /// Given an `Id` using the `egraph[id]` syntax, retrieve the e-class. impl> std::ops::Index for EGraph { - type Output = EClass; + type Output = EClass; fn index(&self, id: Id) -> &Self::Output { let id = self.find(id); self.classes @@ -557,10 +632,7 @@ impl> std::ops::Index for EGraph { /// reference to the e-class. impl> std::ops::IndexMut for EGraph { fn index_mut(&mut self, id: Id) -> &mut Self::Output { - let id = self.find_mut(id); - self.classes - .get_mut(&id) - .unwrap_or_else(|| panic!("Invalid id {}", id)) + self.class_and_rest(id).0 } } @@ -586,7 +658,7 @@ impl> EGraph { /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical /// - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` + /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { let nodes = expr.as_ref(); let mut new_ids = Vec::with_capacity(nodes.len()); @@ -624,7 +696,7 @@ impl> EGraph { /// canonical /// /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an corrispond to the + /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the /// instantiation of the pattern fn add_instantiation_noncanonical(&mut self, pat: &PatternAst, subst: &Subst) -> Id { let nodes = pat.as_ref(); @@ -744,7 +816,7 @@ impl> EGraph { /// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will /// correspond to the parameter `enode` /// - /// # Example + /// ## Example /// ``` /// # use egg::*; /// let mut egraph: EGraph = EGraph::default().with_explanations_enabled(); @@ -759,6 +831,25 @@ impl> EGraph { /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap()); /// ``` + /// + /// When explanations are not enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will + /// produce an expression with equivalent but not necessarily identical children + /// + /// # Example + /// ``` + /// # use egg::*; + /// let mut egraph: EGraph = EGraph::default().with_explanations_disabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.union(a, b); + /// egraph.rebuild(); + /// + /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + /// + /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); + /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap()); + /// ``` pub fn add_uncanonical(&mut self, mut enode: L) -> Id { let original = enode.clone(); if let Some(existing_id) = self.lookup_internal(&mut enode) { @@ -769,7 +860,9 @@ impl> EGraph { *existing_explain } else { let new_id = self.unionfind.make_set(); - explain.add(original, new_id, new_id); + explain.add(original.clone(), new_id, new_id); + debug_assert_eq!(Id::from(self.nodes.len()), new_id); + self.nodes.push(original); self.unionfind.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); new_id @@ -778,37 +871,38 @@ impl> EGraph { existing_id } } else { - let id = self.make_new_eclass(enode); + let id = self.make_new_eclass(enode, original.clone()); if let Some(explain) = self.explain.as_mut() { explain.add(original, id, id); } // now that we updated explanations, run the analysis for the new eclass - N::modify(self, id); + N::post_make(Just::new(self), id); self.clean = false; id } } /// This function makes a new eclass in the egraph (but doesn't touch explanations) - fn make_new_eclass(&mut self, enode: L) -> Id { + fn make_new_eclass(&mut self, enode: L, original: L) -> Id { let id = self.unionfind.make_set(); log::trace!(" ...adding to {}", id); let class = EClass { id, - nodes: vec![enode.clone()], - data: N::make(self, &enode), + data: N::make(Just::new(self), &original), parents: Default::default(), }; + debug_assert_eq!(Id::from(self.nodes.len()), id); + self.nodes.push(original); + // add this enode to the parent lists of its children enode.for_each(|child| { - let tup = (enode.clone(), id); - self[child].parents.push(tup); + self[child].parents.push(id); }); // TODO is this needed? - self.pending.push((enode.clone(), id)); + self.pending.push(id); self.classes.insert(id, class); assert!(self.memo.insert(enode, id).is_none()); @@ -816,32 +910,6 @@ impl> EGraph { id } - /// Checks whether two [`RecExpr`]s are equivalent. - /// Returns a list of id where both expression are represented. - /// In most cases, there will none or exactly one id. - /// - pub fn equivs(&self, expr1: &RecExpr, expr2: &RecExpr) -> Vec { - let pat1 = Pattern::from(expr1.as_ref()); - let pat2 = Pattern::from(expr2.as_ref()); - let matches1 = pat1.search(self); - trace!("Matches1: {:?}", matches1); - - let matches2 = pat2.search(self); - trace!("Matches2: {:?}", matches2); - - let mut equiv_eclasses = Vec::new(); - - for m1 in &matches1 { - for m2 in &matches2 { - if self.find(m1.eclass) == self.find(m2.eclass) { - equiv_eclasses.push(m1.eclass) - } - } - } - - equiv_eclasses - } - /// Given two patterns and a substitution, add the patterns /// and union them. /// @@ -911,7 +979,7 @@ impl> EGraph { rule: Option, any_new_rhs: bool, ) -> bool { - N::pre_union(self, enode_id1, enode_id2, &rule); + N::pre_union(Just::new(self), enode_id1, enode_id2, &rule); self.clean = false; let mut id1 = self.find_mut(enode_id1); @@ -932,7 +1000,7 @@ impl> EGraph { } if let Some(explain) = &mut self.explain { - explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); + explain.union(enode_id1, enode_id2, rule.clone().unwrap(), any_new_rhs); } // make id1 the new root @@ -943,35 +1011,19 @@ impl> EGraph { let class1 = self.classes.get_mut(&id1).unwrap(); assert_eq!(id1, class1.id); - self.pending.extend(class2.parents.iter().cloned()); - let did_merge = self.analysis.merge(&mut class1.data, class2.data); - if did_merge.0 { - self.analysis_pending.extend(class1.parents.iter().cloned()); - } - if did_merge.1 { - self.analysis_pending.extend(class2.parents.iter().cloned()); - } - - concat_vecs(&mut class1.nodes, class2.nodes); - concat_vecs(&mut class1.parents, class2.parents); - - N::modify(self, id1); + self.pending.extend(class2.parents.iter().copied()); + class1.parents.extend(&class2.parents); + N::merge( + Just::new(self), + id1, + id2, + class2.data, + &class2.parents, + &rule, + ); true } - /// Update the analysis data of an e-class. - /// - /// This also propagates the changes through the e-graph, - /// so [`Analysis::make`] and [`Analysis::merge`] will get - /// called for other parts of the e-graph on rebuild. - pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) { - let id = self.find_mut(id); - let class = self.classes.get_mut(&id).unwrap(); - class.data = new_data; - self.analysis_pending.extend(class.parents.iter().cloned()); - N::modify(self, id) - } - /// Returns a more debug-able representation of the egraph. /// /// [`EGraph`]s implement [`Debug`], but it ain't pretty. It @@ -983,104 +1035,46 @@ impl> EGraph { pub fn dump(&self) -> impl Debug + '_ { EGraphDump(self) } -} -impl> EGraph { - /// Panic if the given eclass doesn't contain the given patterns - /// - /// Useful for testing. - pub fn check_goals(&self, id: Id, goals: &[Pattern]) { - let (cost, best) = Extractor::new(self, AstSize).find_best(id); - println!("End ({}): {}", cost, best.pretty(80)); - - for (i, goal) in goals.iter().enumerate() { - println!("Trying to prove goal {}: {}", i, goal.pretty(40)); - let matches = goal.search_eclass(self, id); - if matches.is_none() { - let best = Extractor::new(self, AstSize).find_best(id).1; - panic!( - "Could not prove goal {}:\n\ - {}\n\ - Best thing found:\n\ - {}", - i, - goal.pretty(40), - best.pretty(40), - ); - } - } - } -} - -// All the rebuilding stuff -impl> EGraph { - #[inline(never)] - fn rebuild_classes(&mut self) -> usize { - let mut classes_by_op = std::mem::take(&mut self.classes_by_op); - classes_by_op.values_mut().for_each(|ids| ids.clear()); - - let mut trimmed = 0; - let uf = &mut self.unionfind; - - for class in self.classes.values_mut() { - let old_len = class.len(); - class - .nodes - .iter_mut() - .for_each(|n| n.update_children(|id| uf.find_mut(id))); - class.nodes.sort_unstable(); - class.nodes.dedup(); - - trimmed += old_len - class.nodes.len(); - - let mut add = |n: &L| { - classes_by_op - .entry(n.discriminant()) - .or_default() - .insert(class.id) - }; - - // we can go through the ops in order to dedup them, becaue we - // just sorted them - let mut nodes = class.nodes.iter(); - if let Some(mut prev) = nodes.next() { - add(prev); - for n in nodes { - if !prev.matches(n) { - add(n); - prev = n; - } + pub(crate) fn generate_class_nodes(&self) -> HashMap> { + let mut classes = HashMap::default(); + let find = |id| self.find(id); + for (id, node) in self.nodes() { + let id = find(id); + let node = node.clone().map_children(find); + match classes.get_mut(&id) { + None => { + classes.insert(id, vec![node]); } + Some(x) => x.push(node), } } - #[cfg(debug_assertions)] - for ids in classes_by_op.values_mut() { - let unique: HashSet = ids.iter().copied().collect(); - assert_eq!(ids.len(), unique.len()); + // define all the nodes, clustered by eclass + for class in classes.values_mut() { + class.sort_unstable(); + class.dedup(); } - - self.classes_by_op = classes_by_op; - trimmed + classes } +} +// All the rebuilding stuff +impl> EGraph { #[inline(never)] fn check_memo(&self) -> bool { let mut test_memo = HashMap::default(); - for (&id, class) in self.classes.iter() { - assert_eq!(class.id, id); - for node in &class.nodes { - if let Some(old) = test_memo.insert(node, id) { - assert_eq!( - self.find(old), - self.find(id), - "Found unexpected equivalence for {:?}\n{:?}\nvs\n{:?}", - node, - self[self.find(id)].nodes, - self[self.find(old)].nodes, - ); - } + for (id, node) in self.nodes.iter().enumerate() { + let node = node.clone().map_children(|x| self.unionfind.find(x)); + let id = self.unionfind.find(Id::from(id)); + if let Some(old) = test_memo.insert(node.clone(), id) { + assert_eq!( + self.find(old), + self.find(id), + "Found unexpected equivalence for {:?}", + node, + ); } } @@ -1088,7 +1082,7 @@ impl> EGraph { assert_eq!(e, self.find(e)); assert_eq!( Some(e), - self.memo.get(n).map(|id| self.find(*id)), + self.memo.get(&n).map(|id| self.find(*id)), "Entry for {:?} at {} in test_memo was incorrect", n, e @@ -1102,8 +1096,9 @@ impl> EGraph { fn process_unions(&mut self) -> usize { let mut n_unions = 0; - while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some((mut node, class)) = self.pending.pop() { + loop { + while let Some(class) = self.pending.pop() { + let mut node = self.nodes[usize::from(class)].clone(); node.update_children(|id| self.find_mut(id)); if let Some(memo_class) = self.memo.insert(node, class) { let did_something = self.perform_union( @@ -1116,21 +1111,12 @@ impl> EGraph { } } - while let Some((node, class_id)) = self.analysis_pending.pop() { - let class_id = self.find_mut(class_id); - let node_data = N::make(self, &node); - let class = self.classes.get_mut(&class_id).unwrap(); - - let did_merge = self.analysis.merge(&mut class.data, node_data); - if did_merge.0 { - self.analysis_pending.extend(class.parents.iter().cloned()); - N::modify(self, class_id) - } + if !N::rebuild(Just::new(self), false) && self.pending.is_empty() { + break; } } assert!(self.pending.is_empty()); - assert!(self.analysis_pending.is_empty()); n_unions } @@ -1179,7 +1165,7 @@ impl> EGraph { let start = Instant::now(); let n_unions = self.process_unions(); - let trimmed_nodes = self.rebuild_classes(); + // let trimmed_nodes = self.rebuild_classes(); let elapsed = start.elapsed(); info!( @@ -1187,7 +1173,7 @@ impl> EGraph { "REBUILT! in {}.{:03}s\n", " Old: hc size {}, eclasses: {}\n", " New: hc size {}, eclasses: {}\n", - " unions: {}, trimmed nodes: {}" + " unions: {}" ), elapsed.as_secs(), elapsed.subsec_millis(), @@ -1196,21 +1182,12 @@ impl> EGraph { self.memo.len(), self.number_of_classes(), n_unions, - trimmed_nodes, ); debug_assert!(self.check_memo()); self.clean = true; n_unions } - - pub(crate) fn check_each_explain(&self, rules: &[&Rewrite]) -> bool { - if let Some(explain) = &self.explain { - explain.check_each_explain(rules) - } else { - panic!("Can't check explain when explanations are off"); - } - } } struct EGraphDump<'a, L: Language, N: Analysis>(&'a EGraph); @@ -1220,9 +1197,7 @@ impl<'a, L: Language, N: Analysis> Debug for EGraphDump<'a, L, N> { let mut ids: Vec = self.0.classes().map(|c| c.id).collect(); ids.sort(); for id in ids { - let mut nodes = self.0[id].nodes.clone(); - nodes.sort(); - writeln!(f, "{} ({:?}): {:?}", id, self.0[id].data, nodes)? + writeln!(f, "{}: {:?}", id, self.0[id].data)? } Ok(()) } diff --git a/src/ematch_analysis.rs b/src/ematch_analysis.rs new file mode 100644 index 00000000..4a1fbba3 --- /dev/null +++ b/src/ematch_analysis.rs @@ -0,0 +1,177 @@ +use crate::legacy::{concat_vecs, HashMap, HashSet}; +use crate::{ + Analysis, AnalysisData, AstSize, EClass, EGraph, EGraphT, Extractor, Id, Justification, + Language, Pattern, RecExpr, Rewrite, Searcher, +}; +use log::trace; +use std::fmt::Display; +use std::mem; + +/// An [`EGraph`] that includes an [`EMatchingAnalysis`] +pub type EMGraph = EGraph)>; + +/// An [`EClass`] that includes an [`EMatchingAnalysis`] +pub type EMClass = EClass<(D, EMatchingData)>; + +/// An [`Analysis`] that supports backtracking e-matching +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct EMatchingAnalysis { + #[cfg_attr(feature = "serde-1", serde(skip))] + #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))] + pub(crate) classes_by_op: HashMap>, +} + +#[cfg(feature = "serde-1")] +fn default_classes_by_op() -> HashMap> { + HashMap::default() +} + +pub type EMatchingData = Vec; + +impl Default for EMatchingAnalysis { + fn default() -> Self { + EMatchingAnalysis { + classes_by_op: Default::default(), + } + } +} + +impl AnalysisData for EMatchingAnalysis { + type Data = EMatchingData; // EClass.nodes +} + +impl Analysis for EMatchingAnalysis { + fn make>(_: E, enode: &L) -> Self::Data { + vec![enode.clone()] + } + + fn merge>( + mut egraph: E, + new_root: Id, + _: Id, + other_data: Self::Data, + _: &[Id], + _: &Option, + ) { + concat_vecs(egraph.data_mut(new_root), other_data) + } + + fn rebuild>(mut egraph: E, will_repeat: bool) -> bool { + if will_repeat { + return false; + } + let mut classes_by_op = mem::take(&mut egraph.analysis_mut().classes_by_op); + classes_by_op.values_mut().for_each(|ids| ids.clear()); + let egraph = egraph.deref_mut(); + + let mut trimmed = 0; + let uf = &mut egraph.unionfind; + + for class in egraph.classes.values_mut() { + let id = class.id; + let data = E::proj_data_mut(&mut class.data); + let old_len = data.len(); + data.iter_mut() + .for_each(|n| n.update_children(|id| uf.find_mut(id))); + data.sort_unstable(); + data.dedup(); + + trimmed += old_len - data.len(); + + let mut add = |n: &L| { + classes_by_op + .entry(n.discriminant()) + .or_default() + .insert(id) + }; + + // we can go through the ops in order to dedup them, becaue we + // just sorted them + let mut nodes = data.iter(); + if let Some(mut prev) = nodes.next() { + add(prev); + for n in nodes { + if !prev.matches(n) { + add(n); + prev = n; + } + } + } + } + + #[cfg(debug_assertions)] + for ids in classes_by_op.values_mut() { + let unique: HashSet = ids.iter().copied().collect(); + assert_eq!(ids.len(), unique.len()); + } + + E::proj_mut(&mut egraph.analysis).classes_by_op = classes_by_op; + log::info!("trimmed nodes: {trimmed}"); + false + } +} + +impl> EMGraph { + pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite]) -> bool { + if let Some(explain) = &mut self.explain { + explain.with_nodes(&self.nodes).check_each_explain(rules) + } else { + panic!("Can't check explain when explanations are off"); + } + } + + /// Checks whether two [`RecExpr`]s are equivalent. + /// Returns a list of id where both expression are represented. + /// In most cases, there will none or exactly one id. + /// + pub fn equivs(&self, expr1: &RecExpr, expr2: &RecExpr) -> Vec { + let pat1 = Pattern::from(expr1.as_ref()); + let pat2 = Pattern::from(expr2.as_ref()); + let matches1 = pat1.search(self); + trace!("Matches1: {:?}", matches1); + + let matches2 = pat2.search(self); + trace!("Matches2: {:?}", matches2); + + let mut equiv_eclasses = Vec::new(); + + for m1 in &matches1 { + for m2 in &matches2 { + if self.find(m1.eclass) == self.find(m2.eclass) { + equiv_eclasses.push(m1.eclass) + } + } + } + + equiv_eclasses + } + + /// Panic if the given eclass doesn't contain the given patterns + /// + /// Useful for testing. + pub fn check_goals(&self, id: Id, goals: &[Pattern]) + where + L: Display, + { + let (cost, best) = Extractor::new(self, AstSize).find_best(id); + println!("End ({}): {}", cost, best.pretty(80)); + + for (i, goal) in goals.iter().enumerate() { + println!("Trying to prove goal {}: {}", i, goal.pretty(40)); + let matches = goal.search_eclass(self, id); + if matches.is_none() { + let best = Extractor::new(self, AstSize).find_best(id).1; + panic!( + "Could not prove goal {}:\n\ + {}\n\ + Best thing found:\n\ + {}", + i, + goal.pretty(40), + best.pretty(40), + ); + } + } + } +} diff --git a/src/explain.rs b/src/explain.rs index 187aecfc..06834943 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,12 +1,13 @@ use crate::Symbol; use crate::{ - util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id, - Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var, + util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, + PatternAst, RecExpr, Rewrite, UnionFind, Var, }; use saturating::Saturating; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; use symbolic_expressions::Sexp; @@ -38,8 +39,7 @@ struct Connection { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -struct ExplainNode { - node: L, +struct ExplainNode { // neighbors includes parent connections neighbors: Vec, parent_connection: Connection, @@ -54,8 +54,15 @@ struct ExplainNode { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct Explain { - explainfind: Vec>, + explainfind: Vec, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + #[cfg_attr( + feature = "serde-1", + serde(bound( + serialize = "L: serde::Serialize", + deserialize = "L: serde::Deserialize<'de>", + )) + )] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. pub optimize_explanation_lengths: bool, @@ -69,6 +76,11 @@ pub struct Explain { shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, } +pub(crate) struct ExplainNodes<'a, L: Language> { + explain: &'a mut Explain, + nodes: &'a [L], +} + #[derive(Default)] struct DistanceMemo { parent_distance: Vec<(Id, ProofCost)>, @@ -883,97 +895,6 @@ impl PartialOrd for HeapState { } impl Explain { - pub(crate) fn node(&self, node_id: Id) -> &L { - &self.explainfind[usize::from(node_id)].node - } - fn node_to_explanation( - &self, - node_id: Id, - cache: &mut NodeExplanationCache, - ) -> Rc> { - if let Some(existing) = cache.get(&node_id) { - existing.clone() - } else { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(vec![self.node_to_explanation(child, cache)]); - sofar - }); - let res = Rc::new(TreeTerm::new(node, children)); - cache.insert(node_id, res.clone()); - res - } - } - - pub(crate) fn node_to_recexpr(&self, node_id: Id) -> RecExpr { - let mut res = Default::default(); - let mut cache = Default::default(); - self.node_to_recexpr_internal(&mut res, node_id, &mut cache); - res - } - fn node_to_recexpr_internal( - &self, - res: &mut RecExpr, - node_id: Id, - cache: &mut HashMap, - ) { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_recexpr_internal(res, child, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(new_node); - } - - pub(crate) fn node_to_pattern( - &self, - node_id: Id, - substitutions: &HashMap, - ) -> (Pattern, Subst) { - let mut res = Default::default(); - let mut subst = Default::default(); - let mut cache = Default::default(); - self.node_to_pattern_internal(&mut res, node_id, substitutions, &mut subst, &mut cache); - (Pattern::new(res), subst) - } - - fn node_to_pattern_internal( - &self, - res: &mut PatternAst, - node_id: Id, - var_substitutions: &HashMap, - subst: &mut Subst, - cache: &mut HashMap, - ) { - if let Some(existing) = var_substitutions.get(&node_id) { - let var = format!("?{}", node_id).parse().unwrap(); - res.add(ENodeOrVar::Var(var)); - subst.insert(var, *existing); - } else { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_pattern_internal(res, child, var_substitutions, subst, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(ENodeOrVar::ENode(new_node)); - } - } - - fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(self.node_to_flat_explanation(child)); - sofar - }); - FlatTerm::new(node, children) - } - fn make_rule_table<'a, N: Analysis>( rules: &[&'a Rewrite], ) -> HashMap> { @@ -983,52 +904,6 @@ impl Explain { } table } - - pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { - let rule_table = Explain::make_rule_table(rules); - for i in 0..self.explainfind.len() { - let explain_node = &self.explainfind[i]; - - // check that explanation reasons never form a cycle - let mut existance = i; - let mut seen_existance: HashSet = Default::default(); - loop { - seen_existance.insert(existance); - let next = usize::from(self.explainfind[existance].existance_node); - if existance == next { - break; - } - existance = next; - if seen_existance.contains(&existance) { - panic!("Cycle in existance!"); - } - } - - if explain_node.parent_connection.next != Id::from(i) { - let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); - let mut next_explanation = - self.node_to_flat_explanation(explain_node.parent_connection.next); - if let Justification::Rule(rule_name) = - &explain_node.parent_connection.justification - { - if let Some(rule) = rule_table.get(rule_name) { - if !explain_node.parent_connection.is_rewrite_forward { - std::mem::swap(&mut current_explanation, &mut next_explanation); - } - if !Explanation::check_rewrite( - ¤t_explanation, - &next_explanation, - rule, - ) { - return false; - } - } - } - } - } - true - } - pub fn new() -> Self { Explain { explainfind: vec![], @@ -1044,9 +919,8 @@ impl Explain { pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id { assert_eq!(self.explainfind.len(), usize::from(set)); - self.uncanon_memo.insert(node.clone(), set); + self.uncanon_memo.insert(node, set); self.explainfind.push(ExplainNode { - node, neighbors: vec![], parent_connection: Connection { justification: Justification::Congruence, @@ -1119,7 +993,7 @@ impl Explain { new_rhs: bool, ) { if let Justification::Congruence = justification { - assert!(self.node(node1).matches(self.node(node2))); + // assert!(self.node(node1).matches(self.node(node2))); } if new_rhs { self.set_existance_reason(node2, node1) @@ -1155,7 +1029,6 @@ impl Explain { .push(other_pconnection); self.explainfind[usize::from(node1)].parent_connection = pconnection; } - pub(crate) fn get_union_equalities(&self) -> UnionEqualities { let mut equalities = vec![]; for node in &self.explainfind { @@ -1170,13 +1043,103 @@ impl Explain { equalities } - pub(crate) fn populate_enodes>(&self, mut egraph: EGraph) -> EGraph { - for i in 0..self.explainfind.len() { - let node = &self.explainfind[i]; - egraph.add(node.node.clone()); + pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> { + ExplainNodes { + explain: self, + nodes, } + } +} + +impl<'a, L: Language> Deref for ExplainNodes<'a, L> { + type Target = Explain; + + fn deref(&self) -> &Self::Target { + self.explain + } +} + +impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.explain + } +} + +impl<'x, L: Language> ExplainNodes<'x, L> { + pub(crate) fn node(&self, node_id: Id) -> &L { + &self.nodes[usize::from(node_id)] + } + fn node_to_explanation( + &self, + node_id: Id, + cache: &mut NodeExplanationCache, + ) -> Rc> { + if let Some(existing) = cache.get(&node_id) { + existing.clone() + } else { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(vec![self.node_to_explanation(child, cache)]); + sofar + }); + let res = Rc::new(TreeTerm::new(node, children)); + cache.insert(node_id, res.clone()); + res + } + } + + fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(self.node_to_flat_explanation(child)); + sofar + }); + FlatTerm::new(node, children) + } + + pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { + let rule_table = Explain::make_rule_table(rules); + for i in 0..self.explainfind.len() { + let explain_node = &self.explainfind[i]; + + // check that explanation reasons never form a cycle + let mut existance = i; + let mut seen_existance: HashSet = Default::default(); + loop { + seen_existance.insert(existance); + let next = usize::from(self.explainfind[existance].existance_node); + if existance == next { + break; + } + existance = next; + if seen_existance.contains(&existance) { + panic!("Cycle in existance!"); + } + } - egraph + if explain_node.parent_connection.next != Id::from(i) { + let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); + let mut next_explanation = + self.node_to_flat_explanation(explain_node.parent_connection.next); + if let Justification::Rule(rule_name) = + &explain_node.parent_connection.justification + { + if let Some(rule) = rule_table.get(rule_name) { + if !explain_node.parent_connection.is_rewrite_forward { + std::mem::swap(&mut current_explanation, &mut next_explanation); + } + if !Explanation::check_rewrite( + ¤t_explanation, + &next_explanation, + rule, + ) { + return false; + } + } + } + } + } + true } pub(crate) fn explain_equivalence>( @@ -1184,7 +1147,7 @@ impl Explain { left: Id, right: Id, unionfind: &mut UnionFind, - classes: &HashMap>, + classes: &HashMap>, ) -> Explanation { if self.optimize_explanation_lengths { self.calculate_shortest_explanations::(left, right, classes, unionfind); @@ -1328,7 +1291,7 @@ impl Explain { let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone(); let mut index_of_child = 0; let mut found = false; - existance_node.node.for_each(|child| { + self.node(existance).for_each(|child| { if found { return; } @@ -1627,7 +1590,7 @@ impl Explain { fn find_congruence_neighbors>( &self, - classes: &HashMap>, + classes: &HashMap>, congruence_neighbors: &mut [Vec], unionfind: &UnionFind, ) { @@ -1673,7 +1636,7 @@ impl Explain { pub fn get_num_congr>( &self, - classes: &HashMap>, + classes: &HashMap>, unionfind: &UnionFind, ) -> usize { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; @@ -1890,7 +1853,7 @@ impl Explain { fn calculate_common_ancestor>( &self, - classes: &HashMap>, + classes: &HashMap>, congruence_neighbors: &[Vec], ) -> HashMap<(Id, Id), Id> { let mut common_ancestor_queries = HashMap::default(); @@ -1960,7 +1923,7 @@ impl Explain { &mut self, start: Id, end: Id, - classes: &HashMap>, + classes: &HashMap>, unionfind: &UnionFind, ) { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; @@ -2092,7 +2055,7 @@ mod tests { #[test] fn simple_explain_union_trusted() { - use crate::SymbolLang; + use crate::{EGraph, SymbolLang}; crate::init_logger(); let mut egraph = EGraph::new(()).with_explanations_enabled(); diff --git a/src/extract.rs b/src/extract.rs index 74fbdc60..75c8863b 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -2,7 +2,7 @@ use std::cmp::Ordering; use std::fmt::Debug; use crate::util::HashMap; -use crate::{Analysis, EClass, EGraph, Id, Language, RecExpr}; +use crate::{Analysis, EGraph, Id, Language, RecExpr}; /** Extracting a single [`RecExpr`] from an [`EGraph`]. @@ -215,7 +215,7 @@ where egraph, cost_function, }; - extractor.find_costs(); + extractor.find_costs(egraph.generate_class_nodes()); extractor } @@ -251,20 +251,20 @@ where } } - fn find_costs(&mut self) { + fn find_costs(&mut self, class_nodes: HashMap>) { let mut did_something = true; while did_something { did_something = false; - for class in self.egraph.classes() { - let pass = self.make_pass(class); - match (self.costs.get(&class.id), pass) { + for (&id, nodes) in &class_nodes { + let pass = self.make_pass(id, nodes); + match (self.costs.get(&id), pass) { (None, Some(new)) => { - self.costs.insert(class.id, new); + self.costs.insert(id, new); did_something = true; } (Some(old), Some(new)) if new.0 < old.0 => { - self.costs.insert(class.id, new); + self.costs.insert(id, new); did_something = true; } _ => (), @@ -272,23 +272,19 @@ where } } - for class in self.egraph.classes() { - if !self.costs.contains_key(&class.id) { - log::warn!( - "Failed to compute cost for eclass {}: {:?}", - class.id, - class.nodes - ) + for (&id, nodes) in &class_nodes { + if !self.costs.contains_key(&id) { + log::warn!("Failed to compute cost for eclass {}: {:?}", id, nodes) } } } - fn make_pass(&mut self, eclass: &EClass) -> Option<(CF::Cost, L)> { - let (cost, node) = eclass + fn make_pass(&mut self, id: Id, nodes: &[L]) -> Option<(CF::Cost, L)> { + let (cost, node) = nodes .iter() .map(|n| (self.node_total_cost(n), n)) .min_by(|a, b| cmp(&a.0, &b.0)) - .unwrap_or_else(|| panic!("Can't extract, eclass is empty: {:#?}", eclass)); + .unwrap_or_else(|| panic!("Can't extract, eclass is empty: {:?}", id)); cost.map(|c| (c, node.clone())) } } diff --git a/src/generic_analysis.rs b/src/generic_analysis.rs new file mode 100644 index 00000000..20490f96 --- /dev/null +++ b/src/generic_analysis.rs @@ -0,0 +1,316 @@ +use crate::{EGraph, Id, Justification, Language}; +use std::fmt::Debug; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; + +pub trait Extractor: 'static { + type Out; + + fn proj(x: &X) -> &Self::Out; + + fn proj_mut(x: &mut X) -> &mut Self::Out; +} + +/// An [`EGraph`] that is focusing on a specific sub [`Analysis`] +/// +/// It implements [`DerefMut`] to extract the underlying [`EGraph`] and includes methods +/// to convert the full [`Analysis`] from the [`EGraph`] to the specific sub [`Analysis`] +pub trait EGraphT: DerefMut> + Sized { + /// The full [`Analysis`] of the [`EGraph`] + type A: Analysis; + + /// The sub [`Analysis`] being focused on + type N: AnalysisData; + type E: Extractor; + + type ED: Extractor<>::Data, Out = >::Data>; + fn proj(x: &Self::A) -> &Self::N { + Self::E::proj(x) + } + fn proj_mut(x: &mut Self::A) -> &mut Self::N { + Self::E::proj_mut(x) + } + + fn proj_data(x: &>::Data) -> &>::Data { + Self::ED::proj(x) + } + + fn proj_data_mut( + x: &mut >::Data, + ) -> &mut >::Data { + Self::ED::proj_mut(x) + } + + fn analysis<'a>(&'a self) -> &'a Self::N + where + L: 'a, + { + Self::proj(&self.analysis) + } + + fn analysis_mut<'a>(&'a mut self) -> &'a mut Self::N + where + L: 'a, + { + Self::proj_mut(&mut self.analysis) + } + + fn data<'a>(&'a self, id: Id) -> &'a >::Data + where + L: 'a, + { + Self::proj_data(&self[id].data) + } + + fn data_mut<'a>(&'a mut self, id: Id) -> &'a mut >::Data + where + L: 'a, + { + Self::proj_data_mut(&mut self[id].data) + } + + fn shift( + &mut self, + ) -> EGraphC<'_, EGraph, ExtrCompose, ExtrCompose> { + EGraphC::new(&mut **self) + } +} + +pub struct EGraphC<'a, T, E, ED>(&'a mut T, PhantomData<(E, ED)>); + +impl<'a, T, E, ED> Deref for EGraphC<'a, T, E, ED> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, T, E, ED> DerefMut for EGraphC<'a, T, E, ED> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'a, L: Language, N: Analysis, E: Extractor, ED> EGraphT + for EGraphC<'a, EGraph, E, ED> +where + E::Out: AnalysisData, + ED: Extractor>::Data>, +{ + type A = N; + type N = E::Out; + + type E = E; + + type ED = ED; +} + +impl<'a, T, E, ED> EGraphC<'a, T, E, ED> { + pub fn new(t: &'a mut T) -> Self { + EGraphC(t, PhantomData) + } +} + +pub(crate) struct ExtrId; + +impl Extractor for ExtrId { + type Out = X; + + fn proj(x: &X) -> &Self::Out { + x + } + + fn proj_mut(x: &mut X) -> &mut Self::Out { + x + } +} + +pub(crate) type Just<'a, T> = EGraphC<'a, T, ExtrId, ExtrId>; + +pub struct ExtrCompose(E, F); + +impl, F: Extractor> Extractor for ExtrCompose { + type Out = F::Out; + + fn proj(x: &X) -> &Self::Out { + F::proj(E::proj(x)) + } + + fn proj_mut(x: &mut X) -> &mut Self::Out { + F::proj_mut(E::proj_mut(x)) + } +} + +pub(crate) struct Extr0; + +impl Extractor<(X, Y)> for Extr0 { + type Out = X; + + fn proj(x: &(X, Y)) -> &Self::Out { + &x.0 + } + + fn proj_mut(x: &mut (X, Y)) -> &mut Self::Out { + &mut x.0 + } +} + +pub(crate) struct Extr1; + +impl Extractor<(X, Y)> for Extr1 { + type Out = Y; + + fn proj(x: &(X, Y)) -> &Self::Out { + &x.1 + } + + fn proj_mut(x: &mut (X, Y)) -> &mut Self::Out { + &mut x.1 + } +} + +/// Contains the `Data` associated type for [`Analysis`] and [`LatticeAnalysis`] +pub trait AnalysisData: Sized { + /// The per-[`EClass`] data for this analysis. + type Data: Debug; +} + +/** Arbitrary data associated with an [`EClass`]. + +[`Analysis`] allows you to associate arbitrary data with each eclass. + +It is also useful for providing hooks to stay consistent with various [`EGraph`] operations + +If you don't care about [`Analysis`], `()` implements it trivally, just use that. +If you want to use multiple [`Analysis`], it is also implemented for pairs + **/ +#[allow(unused_variables)] +pub trait Analysis: AnalysisData { + /// Makes a new [`Analysis`] data for a given e-node. + /// + /// Note the mutable `egraph` parameter: this is needed for some + /// advanced use cases, but most use cases will not need to mutate + /// the e-graph in any way. + /// It is **not** `make`'s responsiblity to insert the e-node; + /// the e-node is "being inserted" when this function is called. + /// Doing so will create an infinite loop. + fn make>(egraph: E, enode: &L) -> Self::Data; + + /// An optional hook that allows the modifications involving the `Id` of newly created nodes + fn post_make>(egraph: E, id: Id) {} + + /// An optional hook that allows inspection before a [`union`] occurs. + /// When explanations are enabled, it gives two ids that represent the two particular terms being unioned, not the canonical ids for the two eclasses. + /// It also gives a justification for the union when explanations are enabled. + /// + /// By default it does nothing. + /// + /// `pre_union` is called _a lot_, so doing anything significant + /// (like printing) will cause things to slow down. + /// + /// [`union`]: EGraph::union() + fn pre_union>( + egraph: E, + id1: Id, + id2: Id, + justification: &Option, + ) { + } + + /// Hook called just after two [`EClass`]es merge. + /// It also gives a justification for the union when explanations are enabled. + /// + /// `new_root` contains the `Id` of the [`EClass`] that was chosen as the parent + /// and can be used to lookup that [`EClass`] in the [`EGraph`] + /// + /// `other_id`, `other_data`, and `other_parents` represent the other [`EClass`] + /// since it is no longer available in the [`EGraph`] + /// + /// This should update `egraph.node(new_root)` to correspond to the merged analysis data. + fn merge>( + egraph: E, + new_root: Id, + other_id: Id, + other_data: Self::Data, + other_parents: &[Id], + justification: &Option, + ) { + } + + /// Hook called when rebuilding the [`EGraph`] + /// + /// Should return `true` if it did any work causes this rebuild to restart after it finishes + /// + /// If `will_repeat` is true this function will necessarily be called again in this rebuild cycle + fn rebuild>(egraph: E, will_repeat: bool) -> bool { + false + } +} + +impl AnalysisData for () { + type Data = (); +} + +impl Analysis for () { + fn make>(_: E, _: &L) -> Self::Data {} +} + +impl, N1: AnalysisData> AnalysisData for (N0, N1) { + type Data = (N0::Data, N1::Data); +} + +impl, N1: Analysis> Analysis for (N0, N1) { + fn make>(mut egraph: E, enode: &L) -> Self::Data { + ( + N0::make(egraph.shift::(), enode), + N1::make(egraph.shift::(), enode), + ) + } + + fn post_make>(mut egraph: E, id: Id) { + N0::post_make(egraph.shift::(), id); + N1::post_make(egraph.shift::(), id); + } + + fn pre_union>( + mut egraph: E, + id1: Id, + id2: Id, + justification: &Option, + ) { + N0::pre_union(egraph.shift::(), id1, id2, justification); + N1::pre_union(egraph.shift::(), id1, id2, justification); + } + + fn merge>( + mut egraph: E, + new_root: Id, + other_id: Id, + other_data: Self::Data, + other_parents: &[Id], + justification: &Option, + ) { + N0::merge( + egraph.shift::(), + new_root, + other_id, + other_data.0, + other_parents, + justification, + ); + N1::merge( + egraph.shift::(), + new_root, + other_id, + other_data.1, + other_parents, + justification, + ); + } + + fn rebuild>(mut egraph: E, will_repeat: bool) -> bool { + let x = N0::rebuild(egraph.shift::(), will_repeat); + N1::rebuild(egraph.shift::(), will_repeat || x) || x + } +} diff --git a/src/language.rs b/src/language.rs index 6414c63a..e2cee1a7 100644 --- a/src/language.rs +++ b/src/language.rs @@ -1,5 +1,5 @@ -use std::ops::{BitOr, Index, IndexMut}; -use std::{cmp::Ordering, convert::TryFrom}; +use std::convert::TryFrom; +use std::ops::{Index, IndexMut}; use std::{ convert::Infallible, fmt::{self, Debug, Display}, @@ -161,11 +161,11 @@ pub trait Language: Debug + Clone + Eq + Ord + Hash { /// You could use this method to perform an "ad-hoc" extraction from the e-graph, /// where you already know which node you want pick for each class: /// ``` - /// # use egg::*; + /// # use egg::legacy::*; /// let mut egraph = EGraph::::default(); /// let expr = "(foo (bar1 (bar2 (bar3 baz))))".parse().unwrap(); /// let root = egraph.add_expr(&expr); - /// let get_first_enode = |id| egraph[id].nodes[0].clone(); + /// let get_first_enode = |id| egraph[id].nodes()[0].clone(); /// let expr2 = get_first_enode(root).build_recexpr(get_first_enode); /// assert_eq!(expr, expr2) /// ``` @@ -596,232 +596,6 @@ impl FromStr for RecExpr { } } -/// Result of [`Analysis::merge`] indicating which of the inputs -/// are different from the merged result. -/// -/// The fields correspond to whether the initial `a` and `b` inputs to [`Analysis::merge`] -/// were different from the final merged value. -/// -/// In both cases the result may be conservative -- they may indicate `true` even -/// when there is no difference between the input and the result. -/// -/// `DidMerge`s can be "or"ed together using the `|` operator. -/// This can be useful for composing analyses. -pub struct DidMerge(pub bool, pub bool); - -impl BitOr for DidMerge { - type Output = DidMerge; - - fn bitor(mut self, rhs: Self) -> Self::Output { - self.0 |= rhs.0; - self.1 |= rhs.1; - self - } -} - -/** Arbitrary data associated with an [`EClass`]. - -`egg` allows you to associate arbitrary data with each eclass. -The [`Analysis`] allows that data to behave well even across eclasses merges. - -[`Analysis`] can prove useful in many situtations. -One common one is constant folding, a kind of partial evaluation. -In that case, the metadata is basically `Option`, storing -the cheapest constant expression (if any) that's equivalent to the -enodes in this eclass. -See the test files [`math.rs`] and [`prop.rs`] for more complex -examples on this usage of [`Analysis`]. - -If you don't care about [`Analysis`], `()` implements it trivally, -just use that. - -# Example - -``` -use egg::{*, rewrite as rw}; - -define_language! { - enum SimpleMath { - "+" = Add([Id; 2]), - "*" = Mul([Id; 2]), - Num(i32), - Symbol(Symbol), - } -} - -// in this case, our analysis itself doesn't require any data, so we can just -// use a unit struct and derive Default -#[derive(Default)] -struct ConstantFolding; -impl Analysis for ConstantFolding { - type Data = Option; - - fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { - egg::merge_max(to, from) - } - - fn make(egraph: &mut EGraph, enode: &SimpleMath) -> Self::Data { - let x = |i: &Id| egraph[*i].data; - match enode { - SimpleMath::Num(n) => Some(*n), - SimpleMath::Add([a, b]) => Some(x(a)? + x(b)?), - SimpleMath::Mul([a, b]) => Some(x(a)? * x(b)?), - _ => None, - } - } - - fn modify(egraph: &mut EGraph, id: Id) { - if let Some(i) = egraph[id].data { - let added = egraph.add(SimpleMath::Num(i)); - egraph.union(id, added); - } - } -} - -let rules = &[ - rw!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), - rw!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), - - rw!("add-0"; "(+ ?a 0)" => "?a"), - rw!("mul-0"; "(* ?a 0)" => "0"), - rw!("mul-1"; "(* ?a 1)" => "?a"), -]; - -let expr = "(+ 0 (* (+ 4 -3) foo))".parse().unwrap(); -let mut runner = Runner::::default().with_expr(&expr).run(rules); -let just_foo = runner.egraph.add_expr(&"foo".parse().unwrap()); -assert_eq!(runner.egraph.find(runner.roots[0]), runner.egraph.find(just_foo)); -``` - -[`math.rs`]: https://github.com/egraphs-good/egg/blob/main/tests/math.rs -[`prop.rs`]: https://github.com/egraphs-good/egg/blob/main/tests/prop.rs -*/ -pub trait Analysis: Sized { - /// The per-[`EClass`] data for this analysis. - type Data: Debug; - - /// Makes a new [`Analysis`] data for a given e-node. - /// - /// Note the mutable `egraph` parameter: this is needed for some - /// advanced use cases, but most use cases will not need to mutate - /// the e-graph in any way. - /// It is **not** `make`'s responsiblity to insert the e-node; - /// the e-node is "being inserted" when this function is called. - /// Doing so will create an infinite loop. - fn make(egraph: &mut EGraph, enode: &L) -> Self::Data; - - /// An optional hook that allows inspection before a [`union`] occurs. - /// When explanations are enabled, it gives two ids that represent the two particular terms being unioned, not the canonical ids for the two eclasses. - /// It also gives a justification for the union when explanations are enabled. - /// - /// By default it does nothing. - /// - /// `pre_union` is called _a lot_, so doing anything significant - /// (like printing) will cause things to slow down. - /// - /// [`union`]: EGraph::union() - #[allow(unused_variables)] - fn pre_union( - egraph: &EGraph, - id1: Id, - id2: Id, - justification: &Option, - ) { - } - - /// Defines how to merge two `Data`s when their containing - /// [`EClass`]es merge. - /// - /// This should update `a` to correspond to the merged analysis - /// data. - /// - /// The result is a `DidMerge(a_merged, b_merged)` indicating whether - /// the merged result is different from `a` and `b` respectively. - /// - /// Since `merge` can modify `a`, let `a0`/`a1` be the value of `a` - /// before/after the call to `merge`, respectively. - /// - /// If `a0 != a1` the result must have `a_merged == true`. This may be - /// conservative -- it may be `true` even if even if `a0 == a1`. - /// - /// If `b != a1` the result must have `b_merged == true`. This may be - /// conservative -- it may be `true` even if even if `b == a1`. - /// - /// This function may modify the [`Analysis`], which can be useful as a way - /// to store information for the [`Analysis::modify`] hook to process, since - /// `modify` has access to the e-graph. - fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge; - - /// A hook that allows the modification of the - /// [`EGraph`]. - /// - /// By default this does nothing. - /// - /// This function is called immediately following - /// `Analysis::merge` when unions are performed. - #[allow(unused_variables)] - fn modify(egraph: &mut EGraph, id: Id) {} -} - -impl Analysis for () { - type Data = (); - fn make(_egraph: &mut EGraph, _enode: &L) -> Self::Data {} - fn merge(&mut self, _: &mut Self::Data, _: Self::Data) -> DidMerge { - DidMerge(false, false) - } -} - -/// A utility for implementing [`Analysis::merge`] -/// when the `Data` type has a total ordering. -/// This will take the maximum of the two values. -pub fn merge_max(to: &mut T, from: T) -> DidMerge { - let cmp = (*to).cmp(&from); - match cmp { - Ordering::Less => { - *to = from; - DidMerge(true, false) - } - Ordering::Equal => DidMerge(false, false), - Ordering::Greater => DidMerge(false, true), - } -} - -/// A utility for implementing [`Analysis::merge`] -/// when the `Data` type has a total ordering. -/// This will take the minimum of the two values. -pub fn merge_min(to: &mut T, from: T) -> DidMerge { - let cmp = (*to).cmp(&from); - match cmp { - Ordering::Less => DidMerge(false, true), - Ordering::Equal => DidMerge(false, false), - Ordering::Greater => { - *to = from; - DidMerge(true, false) - } - } -} - -/// A utility for implementing [`Analysis::merge`] -/// when the `Data` type is an [`Option`]. -/// -/// Always take a `Some` over a `None` -/// and calls the given function to merge two `Some`s. -pub fn merge_option( - to: &mut Option, - from: Option, - merge_fn: impl FnOnce(&mut T, T) -> DidMerge, -) -> DidMerge { - match (to.as_mut(), from) { - (None, None) => DidMerge(false, false), - (None, from @ Some(_)) => { - *to = from; - DidMerge(true, false) - } - (Some(_), None) => DidMerge(false, true), - (Some(a), Some(b)) => merge_fn(a, b), - } -} - /// A simple language used for testing. #[derive(Debug, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] diff --git a/src/lattice_analysis.rs b/src/lattice_analysis.rs new file mode 100644 index 00000000..08c60701 --- /dev/null +++ b/src/lattice_analysis.rs @@ -0,0 +1,352 @@ +use crate::generic_analysis::{Extr0, ExtrId, Extractor, Just}; +use crate::legacy::UniqueQueue; +use crate::{Analysis, AnalysisData, EGraph, EGraphT, Id, Justification, Language}; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::mem; +use std::ops::BitOr; + +/// Result of [`Analysis::merge`] indicating which of the inputs +/// are different from the merged result. +/// +/// The fields correspond to whether the initial `a` and `b` inputs to [`Analysis::merge`] +/// were different from the final merged value. +/// +/// In both cases the result may be conservative -- they may indicate `true` even +/// when there is no difference between the input and the result. +/// +/// `DidMerge`s can be "or"ed together using the `|` operator. +/// This can be useful for composing analyses. +pub struct DidMerge(pub bool, pub bool); + +impl BitOr for DidMerge { + type Output = DidMerge; + + fn bitor(mut self, rhs: Self) -> Self::Output { + self.0 |= rhs.0; + self.1 |= rhs.1; + self + } +} + +/** Arbitrary data associated with an [`EClass`]. + +`egg` allows you to associate arbitrary data with each eclass. +The [`Analysis`] allows that data to behave well even across eclasses merges. + +[`Analysis`] can prove useful in many situtations. +One common one is constant folding, a kind of partial evaluation. +In that case, the metadata is basically `Option`, storing +the cheapest constant expression (if any) that's equivalent to the +enodes in this eclass. +See the test files [`math.rs`] and [`prop.rs`] for more complex +examples on this usage of [`Analysis`]. + +If you don't care about [`Analysis`], `()` implements it trivally, +just use that. + +# Example + +``` +use egg::legacy::{*, rewrite as rw}; + +define_language! { + enum SimpleMath { + "+" = Add([Id; 2]), + "*" = Mul([Id; 2]), + Num(i32), + Symbol(Symbol), + } +} + +impl AnalysisData for ConstantFolding { + type Data = Option; +} + +// in this case, our analysis itself doesn't require any data, so we can just +// use a unit struct and derive Default +#[derive(Default)] +struct ConstantFolding; +impl Analysis for ConstantFolding { + + fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { + egg::merge_max(to, from) + } + + fn make>(egraph: E, enode: &SimpleMath) -> Self::Data { + let x = |i: &Id| *egraph.data(*i); + match enode { + SimpleMath::Num(n) => Some(*n), + SimpleMath::Add([a, b]) => Some(x(a)? + x(b)?), + SimpleMath::Mul([a, b]) => Some(x(a)? * x(b)?), + _ => None, + } + } + + fn modify>(mut egraph: E, id: Id) { + if let Some(i) = *egraph.data(id) { + let added = egraph.add(SimpleMath::Num(i)); + egraph.union(id, added); + } + } +} + +let rules = &[ + rw!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), + rw!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), + + rw!("add-0"; "(+ ?a 0)" => "?a"), + rw!("mul-0"; "(* ?a 0)" => "0"), + rw!("mul-1"; "(* ?a 1)" => "?a"), +]; + +let expr = "(+ 0 (* (+ 4 -3) foo))".parse().unwrap(); +let mut runner = Runner::::default().with_expr(&expr).run(rules); +let just_foo = runner.egraph.add_expr(&"foo".parse().unwrap()); +assert_eq!(runner.egraph.find(runner.roots[0]), runner.egraph.find(just_foo)); +``` + +[`math.rs`]: https://github.com/egraphs-good/egg/blob/main/tests/math.rs +[`prop.rs`]: https://github.com/egraphs-good/egg/blob/main/tests/prop.rs + */ +pub trait LatticeAnalysis: AnalysisData { + /// Makes a new [`Analysis`] data for a given e-node. + /// + /// Note the mutable `egraph` parameter: this is needed for some + /// advanced use cases, but most use cases will not need to mutate + /// the e-graph in any way. + /// It is **not** `make`'s responsiblity to insert the e-node; + /// the e-node is "being inserted" when this function is called. + /// Doing so will create an infinite loop. + fn make>(egraph: E, enode: &L) -> Self::Data; + + /// An optional hook that allows inspection before a [`union`] occurs. + /// When explanations are enabled, it gives two ids that represent the two particular terms being unioned, not the canonical ids for the two eclasses. + /// It also gives a justification for the union when explanations are enabled. + /// + /// By default it does nothing. + /// + /// `pre_union` is called _a lot_, so doing anything significant + /// (like printing) will cause things to slow down. + /// + /// [`union`]: EGraph::union() + #[allow(unused_variables)] + fn pre_union>( + egraph: E, + id1: Id, + id2: Id, + justification: &Option, + ) { + } + + /// Defines how to merge two `Data`s when their containing + /// [`EClass`]es merge. + /// + /// This should update `a` to correspond to the merged analysis + /// data. + /// + /// The result is a `DidMerge(a_merged, b_merged)` indicating whether + /// the merged result is different from `a` and `b` respectively. + /// + /// Since `merge` can modify `a`, let `a0`/`a1` be the value of `a` + /// before/after the call to `merge`, respectively. + /// + /// If `a0 != a1` the result must have `a_merged == true`. This may be + /// conservative -- it may be `true` even if even if `a0 == a1`. + /// + /// If `b != a1` the result must have `b_merged == true`. This may be + /// conservative -- it may be `true` even if even if `b == a1`. + /// + /// This function may modify the [`Analysis`], which can be useful as a way + /// to store information for the [`crate::generic_analysis::Analysis::modify`] hook to process, since + /// `modify` has access to the e-graph. + fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge; + + /// A hook that allows the modification of the + /// [`EGraph`]. + /// + /// By default this does nothing. + /// + /// This function is called immediately following + /// `Analysis::merge` when unions are performed. + #[allow(unused_variables)] + fn modify>(egraph: E, id: Id) {} +} + +impl LatticeAnalysis for () { + fn make>(_: E, _: &L) -> Self::Data {} + + fn merge(&mut self, _: &mut Self::Data, _: Self::Data) -> DidMerge { + DidMerge(false, false) + } +} + +/// A utility for implementing [`Analysis::merge`] +/// when the `Data` type has a total ordering. +/// This will take the maximum of the two values. +pub fn merge_max(to: &mut T, from: T) -> DidMerge { + let cmp = (*to).cmp(&from); + match cmp { + Ordering::Less => { + *to = from; + DidMerge(true, false) + } + Ordering::Equal => DidMerge(false, false), + Ordering::Greater => DidMerge(false, true), + } +} + +/// A utility for implementing [`Analysis::merge`] +/// when the `Data` type has a total ordering. +/// This will take the minimum of the two values. +pub fn merge_min(to: &mut T, from: T) -> DidMerge { + let cmp = (*to).cmp(&from); + match cmp { + Ordering::Less => DidMerge(false, true), + Ordering::Equal => DidMerge(false, false), + Ordering::Greater => { + *to = from; + DidMerge(true, false) + } + } +} + +/// A utility for implementing [`Analysis::merge`] +/// when the `Data` type is an [`Option`]. +/// +/// Always take a `Some` over a `None` +/// and calls the given function to merge two `Some`s. +pub fn merge_option( + to: &mut Option, + from: Option, + merge_fn: impl FnOnce(&mut T, T) -> DidMerge, +) -> DidMerge { + match (to.as_mut(), from) { + (None, None) => DidMerge(false, false), + (None, from @ Some(_)) => { + *to = from; + DidMerge(true, false) + } + (Some(_), None) => DidMerge(false, true), + (Some(a), Some(b)) => merge_fn(a, b), + } +} + +/// Wrapper that converts a [`LatticeAnalysis`] into an [`Analysis`] +#[derive(Clone, Default)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct WrapLatticeAnalysis { + pub analysis: N, + analysis_pending: UniqueQueue, +} + +impl Debug for WrapLatticeAnalysis { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.analysis.fmt(f) + } +} + +impl> AnalysisData for WrapLatticeAnalysis { + type Data = N::Data; +} + +struct ExtrA; + +impl Extractor> for ExtrA { + type Out = N; + + fn proj(x: &WrapLatticeAnalysis) -> &Self::Out { + &x.analysis + } + + fn proj_mut(x: &mut WrapLatticeAnalysis) -> &mut Self::Out { + &mut x.analysis + } +} + +impl> Analysis for WrapLatticeAnalysis { + fn make>(mut egraph: E, enode: &L) -> Self::Data { + N::make(egraph.shift::(), enode) + } + + fn post_make>(mut egraph: E, id: Id) { + N::modify(egraph.shift::(), id) + } + + fn pre_union>( + mut egraph: E, + id1: Id, + id2: Id, + justification: &Option, + ) { + N::pre_union(egraph.shift::(), id1, id2, justification) + } + + fn merge>( + mut egraph: E, + new_root: Id, + _: Id, + other_data: Self::Data, + other_parents: &[Id], + _: &Option, + ) { + let (class, analysis, _) = egraph.class_and_rest(new_root); + let analysis = E::proj_mut(analysis); + let data = E::proj_data_mut(&mut class.data); + let did_merge = analysis.analysis.merge(data, other_data); + if did_merge.0 { + analysis + .analysis_pending + .extend(class.parents.iter().copied()); + } + if did_merge.1 { + analysis + .analysis_pending + .extend(other_parents.iter().copied()); + } + N::modify(egraph.shift::(), new_root) + } + + fn rebuild>(mut egraph: E, _: bool) -> bool { + let mut analysis_pending = mem::take(&mut egraph.analysis_mut().analysis_pending); + if analysis_pending.is_empty() { + return false; + } + + while let Some(class_id) = analysis_pending.pop() { + let node = egraph.id_to_node(class_id).clone(); + let node_data = N::make(egraph.shift::(), &node); + let (class, analysis, _) = egraph.class_and_rest(class_id); + let analysis = E::proj_mut(analysis); + let did_merge = analysis + .analysis + .merge(E::proj_data_mut(&mut class.data), node_data); + if did_merge.0 { + analysis + .analysis_pending + .extend(class.parents.iter().copied()); + N::modify(egraph.shift::(), class_id) + } + } + true + } +} + +impl, O: Analysis> EGraph, O)> { + /// Update the analysis data of an e-class. + /// + /// This also propagates the changes through the e-graph, + /// so [`Analysis::make`] and [`Analysis::merge`] will get + /// called for other parts of the e-graph on rebuild. + pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) { + let (class, analysis, _) = self.class_and_rest(id); + class.data.0 = new_data; + analysis.0.analysis_pending.extend(class.parents()); + N::modify( + Just::new(self) + .shift::() + .shift::(), + id, + ) + } +} diff --git a/src/legacy.rs b/src/legacy.rs new file mode 100644 index 00000000..1c10ecf8 --- /dev/null +++ b/src/legacy.rs @@ -0,0 +1,8 @@ +pub use crate::*; +pub use generic_analysis::Analysis as RawAnalysis; +pub use LatticeAnalysis as Analysis; + +pub type EGraph = EMGraph>; +pub type EClass = EMClass>; +pub type Runner = run::Runner, I>; +pub type Rewrite = rewrite::Rewrite>; diff --git a/src/lib.rs b/src/lib.rs index 5a293a58..6a9cc15a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,9 +40,12 @@ pub mod tutorials; mod dot; mod eclass; mod egraph; +mod ematch_analysis; mod explain; mod extract; +mod generic_analysis; mod language; +mod lattice_analysis; #[cfg(feature = "lp")] mod lp_extract; mod machine; @@ -54,6 +57,9 @@ mod subst; mod unionfind; mod util; +/// Type aliases that are more similar to the types in older versions of `egg` +pub mod legacy; + /// A key to identify [`EClass`]es within an /// [`EGraph`]. #[derive(Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Hash)] @@ -91,12 +97,15 @@ pub use { dot::Dot, eclass::EClass, egraph::EGraph, + ematch_analysis::*, explain::{ Explanation, FlatExplanation, FlatTerm, Justification, TreeExplanation, TreeTerm, UnionEqualities, }, extract::*, + generic_analysis::{Analysis, AnalysisData, EGraphT}, language::*, + lattice_analysis::*, multipattern::*, pattern::{ENodeOrVar, Pattern, PatternAst, SearchMatches}, rewrite::{Applier, Condition, ConditionEqual, ConditionalApplier, Rewrite, Searcher}, diff --git a/src/machine.rs b/src/machine.rs index 30327e22..ed1ddbf7 100644 --- a/src/machine.rs +++ b/src/machine.rs @@ -35,48 +35,43 @@ enum ENodeOrReg { #[inline(always)] fn for_each_matching_node( - eclass: &EClass, + eclass: &EMClass, node: &L, mut f: impl FnMut(&L) -> Result, ) -> Result where L: Language, { - if eclass.nodes.len() < 50 { - eclass - .nodes - .iter() - .filter(|n| node.matches(n)) - .try_for_each(f) + if eclass.len() < 50 { + eclass.iter().filter(|n| node.matches(n)).try_for_each(f) } else { debug_assert!(node.all(|id| id == Id::from(0))); - debug_assert!(eclass.nodes.windows(2).all(|w| w[0] < w[1])); - let mut start = eclass.nodes.binary_search(node).unwrap_or_else(|i| i); + debug_assert!(eclass.nodes().windows(2).all(|w| w[0] < w[1])); + let mut start = eclass.nodes().binary_search(node).unwrap_or_else(|i| i); let discrim = node.discriminant(); while start > 0 { - if eclass.nodes[start - 1].discriminant() == discrim { + if eclass.data.1[start - 1].discriminant() == discrim { start -= 1; } else { break; } } - let mut matching = eclass.nodes[start..] + let mut matching = eclass.nodes()[start..] .iter() .take_while(|&n| n.discriminant() == discrim) .filter(|n| node.matches(n)); debug_assert_eq!( matching.clone().count(), - eclass.nodes.iter().filter(|n| node.matches(n)).count(), + eclass.iter().filter(|n| node.matches(n)).count(), "matching node {:?}\nstart={}\n{:?} != {:?}\nnodes: {:?}", node, start, matching.clone().collect::>(), eclass - .nodes .iter() .filter(|n| node.matches(n)) .collect::>(), - eclass.nodes + eclass.nodes() ); matching.try_for_each(&mut f) } @@ -90,7 +85,7 @@ impl Machine { fn run( &mut self, - egraph: &EGraph, + egraph: &EMGraph, instructions: &[Instruction], subst: &Subst, yield_fn: &mut impl FnMut(&Self, &Subst) -> Result, @@ -343,7 +338,7 @@ impl Program { pub fn run_with_limit( &self, - egraph: &EGraph, + egraph: &EMGraph, eclass: Id, mut limit: usize, ) -> Vec diff --git a/src/macros.rs b/src/macros.rs index 58afca0e..22da3ba5 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -234,7 +234,7 @@ the outermost, and the last condition being the innermost. # Example ``` -# use egg::*; +# use egg::legacy::*; use std::borrow::Cow; use std::sync::Arc; define_language! { @@ -247,7 +247,7 @@ define_language! { } } -type EGraph = egg::EGraph; +type EGraph = egg::legacy::EGraph; let mut rules: Vec> = vec![ rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), @@ -271,7 +271,7 @@ rules.extend(vec![ #[derive(Debug)] struct MySillyApplier(&'static str); -impl Applier for MySillyApplier { +impl Applier> for MySillyApplier { fn apply_one(&self, _: &mut EGraph, _: Id, _: &Subst, _: Option<&PatternAst>, _: Symbol) -> Vec { panic!() } @@ -281,7 +281,7 @@ impl Applier for MySillyApplier { fn is_not_zero(var: &'static str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let var = var.parse().unwrap(); let zero = SimpleLanguage::Num(0); - move |egraph, _, subst| !egraph[subst[var]].nodes.contains(&zero) + move |egraph, _, subst| !egraph[subst[var]].nodes().contains(&zero) } ``` diff --git a/src/multipattern.rs b/src/multipattern.rs index 4fe61212..222e037e 100644 --- a/src/multipattern.rs +++ b/src/multipattern.rs @@ -31,7 +31,7 @@ impl MultiPattern { /// Creates a new multipattern, binding the given patterns to the corresponding variables. /// /// ``` - /// use egg::*; + /// use egg::legacy::*; /// /// let mut egraph = EGraph::::default(); /// egraph.add_expr(&"(f a a)".parse().unwrap()); @@ -103,7 +103,7 @@ impl FromStr for MultiPattern { impl> Searcher for MultiPattern { fn search_eclass_with_limit( &self, - egraph: &EGraph, + egraph: &EMGraph, eclass: Id, limit: usize, ) -> Option> { @@ -138,7 +138,7 @@ impl> Searcher for MultiPattern { impl> Applier for MultiPattern { fn apply_one( &self, - _egraph: &mut EGraph, + _egraph: &mut EMGraph, _eclass: Id, _subst: &Subst, _searcher_ast: Option<&PatternAst>, @@ -149,7 +149,7 @@ impl> Applier for MultiPattern { fn apply_matches( &self, - egraph: &mut EGraph, + egraph: &mut EMGraph, matches: &[SearchMatches], _rule_name: Symbol, ) -> Vec { @@ -197,9 +197,9 @@ impl> Applier for MultiPattern { #[cfg(test)] mod tests { - use crate::{SymbolLang as S, *}; + use crate::legacy::{SymbolLang as S, *}; - type EGraph = crate::EGraph; + type EGraph = crate::legacy::EGraph; impl EGraph { fn add_string(&mut self, s: &str) -> Id { diff --git a/src/pattern.rs b/src/pattern.rs index 3143e82d..14fe7de9 100644 --- a/src/pattern.rs +++ b/src/pattern.rs @@ -35,7 +35,7 @@ use crate::*; /// This is probably how you'll create most [`Pattern`]s. /// /// ``` -/// use egg::*; +/// use egg::legacy::*; /// define_language! { /// enum Math { /// Num(i32), @@ -285,11 +285,11 @@ impl> Searcher for Pattern { Some(&self.ast) } - fn search_with_limit(&self, egraph: &EGraph, limit: usize) -> Vec> { + fn search_with_limit(&self, egraph: &EMGraph, limit: usize) -> Vec> { match self.ast.as_ref().last().unwrap() { ENodeOrVar::ENode(e) => { let key = e.discriminant(); - match egraph.classes_by_op.get(&key) { + match egraph.analysis.1.classes_by_op.get(&key) { None => vec![], Some(ids) => rewrite::search_eclasses_with_limit( self, @@ -310,7 +310,7 @@ impl> Searcher for Pattern { fn search_eclass_with_limit( &self, - egraph: &EGraph, + egraph: &EMGraph, eclass: Id, limit: usize, ) -> Option> { @@ -343,7 +343,7 @@ where fn apply_matches( &self, - egraph: &mut EGraph, + egraph: &mut EMGraph, matches: &[SearchMatches], rule_name: Symbol, ) -> Vec { @@ -375,7 +375,7 @@ where fn apply_one( &self, - egraph: &mut EGraph, + egraph: &mut EMGraph, eclass: Id, subst: &Subst, searcher_ast: Option<&PatternAst>, @@ -408,7 +408,7 @@ where pub(crate) fn apply_pat>( ids: &mut [Id], pat: &[ENodeOrVar], - egraph: &mut EGraph, + egraph: &mut EMGraph, subst: &Subst, ) -> Id { debug_assert_eq!(pat.len(), ids.len()); @@ -434,7 +434,7 @@ mod tests { use crate::{SymbolLang as S, *}; - type EGraph = crate::EGraph; + type EGraph = crate::legacy::EGraph; #[test] fn simple_match() { diff --git a/src/rewrite.rs b/src/rewrite.rs index 6b34b48b..127cb6f9 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -3,7 +3,6 @@ use std::fmt::{self, Debug, Display}; use std::sync::Arc; use crate::*; - /// A rewrite that searches for the lefthand side and applies the righthand side. /// /// The [`rewrite!`] macro is the easiest way to create rewrites. @@ -80,28 +79,28 @@ impl> Rewrite { /// Call [`search`] on the [`Searcher`]. /// /// [`search`]: Searcher::search() - pub fn search(&self, egraph: &EGraph) -> Vec> { + pub fn search(&self, egraph: &EMGraph) -> Vec> { self.searcher.search(egraph) } /// Call [`search_with_limit`] on the [`Searcher`]. /// /// [`search_with_limit`]: Searcher::search_with_limit() - pub fn search_with_limit(&self, egraph: &EGraph, limit: usize) -> Vec> { + pub fn search_with_limit(&self, egraph: &EMGraph, limit: usize) -> Vec> { self.searcher.search_with_limit(egraph, limit) } /// Call [`apply_matches`] on the [`Applier`]. /// /// [`apply_matches`]: Applier::apply_matches() - pub fn apply(&self, egraph: &mut EGraph, matches: &[SearchMatches]) -> Vec { + pub fn apply(&self, egraph: &mut EMGraph, matches: &[SearchMatches]) -> Vec { self.applier.apply_matches(egraph, matches, self.name) } /// This `run` is for testing use only. You should use things /// from the `egg::run` module #[cfg(test)] - pub(crate) fn run(&self, egraph: &mut EGraph) -> Vec { + pub(crate) fn run(&self, egraph: &mut EMGraph) -> Vec { let start = crate::util::Instant::now(); let matches = self.search(egraph); @@ -125,7 +124,7 @@ impl> Rewrite { /// Searches the given list of e-classes with a limit. pub(crate) fn search_eclasses_with_limit<'a, I, S, L, N>( searcher: &'a S, - egraph: &EGraph, + egraph: &EMGraph, eclasses: I, mut limit: usize, ) -> Vec> @@ -166,7 +165,7 @@ where { /// Search one eclass, returning None if no matches can be found. /// This should not return a SearchMatches with no substs. - fn search_eclass(&self, egraph: &EGraph, eclass: Id) -> Option> { + fn search_eclass(&self, egraph: &EMGraph, eclass: Id) -> Option> { self.search_eclass_with_limit(egraph, eclass, usize::MAX) } @@ -179,17 +178,17 @@ where /// [`search_eclass_with_limit`]: Searcher::search_eclass_with_limit fn search_eclass_with_limit( &self, - egraph: &EGraph, + egraph: &EMGraph, eclass: Id, limit: usize, ) -> Option>; - /// Search the whole [`EGraph`], returning a list of all the + /// Search the whole [`EMGraph`], returning a list of all the /// [`SearchMatches`] where something was found. /// This just calls [`search_eclass`] on each eclass. /// /// [`search_eclass`]: Searcher::search_eclass - fn search(&self, egraph: &EGraph) -> Vec> { + fn search(&self, egraph: &EMGraph) -> Vec> { egraph .classes() .filter_map(|e| self.search_eclass(egraph, e.id)) @@ -199,12 +198,12 @@ where /// Similar to [`search`], but return at most `limit` many matches. /// /// [`search`]: Searcher::search - fn search_with_limit(&self, egraph: &EGraph, limit: usize) -> Vec> { + fn search_with_limit(&self, egraph: &EMGraph, limit: usize) -> Vec> { search_eclasses_with_limit(self, egraph, egraph.classes().map(|e| e.id), limit) } /// Returns the number of matches in the e-graph - fn n_matches(&self, egraph: &EGraph) -> usize { + fn n_matches(&self, egraph: &EMGraph) -> usize { self.search(egraph).iter().map(|m| m.substs.len()).sum() } @@ -232,7 +231,7 @@ where /// /// # Example /// ``` -/// use egg::{rewrite as rw, *}; +/// use egg::legacy::{rewrite as rw, *}; /// use std::sync::Arc; /// /// define_language! { @@ -244,19 +243,22 @@ where /// } /// } /// -/// type EGraph = egg::EGraph; +/// type EGraph = egg::legacy::EGraph; /// /// // Our metadata in this case will be size of the smallest /// // represented expression in the eclass. /// #[derive(Default)] /// struct MinSize; +/// impl AnalysisData for MinSize { +/// type Data = usize; +/// } +/// /// impl Analysis for MinSize { -/// type Data = usize; /// fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { /// merge_min(to, from) /// } -/// fn make(egraph: &mut EGraph, enode: &Math) -> Self::Data { -/// let get_size = |i: Id| egraph[i].data; +/// fn make>(egraph: E, enode: &Math) -> Self::Data { +/// let get_size = |i: Id| *egraph.data(i); /// AstSize.cost(enode, get_size) /// } /// } @@ -285,13 +287,13 @@ where /// ast: PatternAst, /// } /// -/// impl Applier for Funky { +/// impl Applier> for Funky { /// /// fn apply_one(&self, egraph: &mut EGraph, matched_id: Id, subst: &Subst, searcher_pattern: Option<&PatternAst>, rule_name: Symbol) -> Vec { /// let a: Id = subst[self.a]; /// // In a custom Applier, you can inspect the analysis data, /// // which is powerful combination! -/// let size_of_a = egraph[a].data; +/// let size_of_a = egraph[a].data.0; /// if size_of_a > 50 { /// println!("Too big! Not doing anything"); /// vec![] @@ -337,7 +339,7 @@ where /// [`apply_one`]: Applier::apply_one() fn apply_matches( &self, - egraph: &mut EGraph, + egraph: &mut EMGraph, matches: &[SearchMatches], rule_name: Symbol, ) -> Vec { @@ -375,7 +377,7 @@ where /// [`apply_matches`]: Applier::apply_matches() fn apply_one( &self, - egraph: &mut EGraph, + egraph: &mut EMGraph, eclass: Id, subst: &Subst, searcher_ast: Option<&PatternAst>, @@ -429,7 +431,7 @@ where fn apply_one( &self, - egraph: &mut EGraph, + egraph: &mut EMGraph, eclass: Id, subst: &Subst, searcher_ast: Option<&PatternAst>, @@ -469,7 +471,7 @@ where /// `eclass` is the eclass [`Id`] where the match (`subst`) occured. /// If this is true, then the [`ConditionalApplier`] will fire. /// - fn check(&self, egraph: &mut EGraph, eclass: Id, subst: &Subst) -> bool; + fn check(&self, egraph: &mut EMGraph, eclass: Id, subst: &Subst) -> bool; /// Returns a list of variables that this Condition assumes are bound. /// @@ -486,9 +488,9 @@ impl Condition for F where L: Language, N: Analysis, - F: Fn(&mut EGraph, Id, &Subst) -> bool, + F: Fn(&mut EMGraph, Id, &Subst) -> bool, { - fn check(&self, egraph: &mut EGraph, eclass: Id, subst: &Subst) -> bool { + fn check(&self, egraph: &mut EMGraph, eclass: Id, subst: &Subst) -> bool { self(egraph, eclass, subst) } } @@ -528,7 +530,7 @@ where L: Language, N: Analysis, { - fn check(&self, egraph: &mut EGraph, _eclass: Id, subst: &Subst) -> bool { + fn check(&self, egraph: &mut EMGraph, _eclass: Id, subst: &Subst) -> bool { let mut id_buf_1 = vec![0.into(); self.p1.ast.as_ref().len()]; let mut id_buf_2 = vec![0.into(); self.p2.ast.as_ref().len()]; let a1 = apply_pat(&mut id_buf_1, self.p1.ast.as_ref(), egraph, subst); @@ -549,7 +551,7 @@ mod tests { use crate::{SymbolLang as S, *}; use std::str::FromStr; - type EGraph = crate::EGraph; + type EGraph = crate::legacy::EGraph; #[test] fn conditional_rewrite() { @@ -600,7 +602,7 @@ mod tests { let root = egraph.add_expr(&start); fn get(egraph: &EGraph, id: Id) -> Symbol { - egraph[id].nodes[0].op + egraph[id].nodes()[0].op } #[derive(Debug)] @@ -608,7 +610,7 @@ mod tests { _rhs: PatternAst, } - impl Applier for Appender { + impl Applier> for Appender { fn apply_one( &self, egraph: &mut EGraph, diff --git a/src/run.rs b/src/run.rs index ae78d84e..264405ee 100644 --- a/src/run.rs +++ b/src/run.rs @@ -136,7 +136,7 @@ println!( */ pub struct Runner, IterData = ()> { /// The [`EGraph`] used. - pub egraph: EGraph, + pub egraph: EMGraph, /// Data accumulated over each [`Iteration`]. pub iterations: Vec>, /// The roots of expressions added by the @@ -318,7 +318,7 @@ where node_limit: 10_000, time_limit: Duration::from_secs(5), - egraph: EGraph::new(analysis), + egraph: EGraph::new((analysis, EMatchingAnalysis::default())), roots: vec![], iterations: vec![], stop_reason: None, @@ -395,7 +395,7 @@ where } /// Replace the [`EGraph`] of this `Runner`. - pub fn with_egraph(self, egraph: EGraph) -> Self { + pub fn with_egraph(self, egraph: EMGraph) -> Self { Self { egraph, ..self } } @@ -686,7 +686,7 @@ where fn search_rewrite<'a>( &mut self, iteration: usize, - egraph: &EGraph, + egraph: &EMGraph, rewrite: &'a Rewrite, ) -> Vec> { rewrite.search(egraph) @@ -701,7 +701,7 @@ where fn apply_rewrite( &mut self, iteration: usize, - egraph: &mut EGraph, + egraph: &mut EMGraph, rewrite: &Rewrite, matches: Vec>, ) -> usize { @@ -866,7 +866,7 @@ where fn search_rewrite<'a>( &mut self, iteration: usize, - egraph: &EGraph, + egraph: &EMGraph, rewrite: &'a Rewrite, ) -> Vec> { let stats = self.rule_stats(rewrite.name); diff --git a/src/test.rs b/src/test.rs index 10815d66..a34f01d2 100644 --- a/src/test.rs +++ b/src/test.rs @@ -7,7 +7,7 @@ use std::{fmt::Display, fs::File, io::Write, path::PathBuf}; use saturating::Saturating; -use crate::*; +use crate::legacy::*; pub fn env_var(s: &str) -> Option where diff --git a/src/tutorials/_02_getting_started.rs b/src/tutorials/_02_getting_started.rs index 895d63f5..512680cd 100644 --- a/src/tutorials/_02_getting_started.rs +++ b/src/tutorials/_02_getting_started.rs @@ -102,7 +102,7 @@ We'll use a [`Pattern`], which implements the [`Searcher`] trait, to search the e-graph for matches: ``` -# use egg::*; +# use egg::legacy::*; // let's make an e-graph let mut egraph: EGraph = Default::default(); let a = egraph.add(SymbolLang::leaf("a")); diff --git a/tests/datalog.rs b/tests/datalog.rs index 2343d681..b06db0b7 100644 --- a/tests/datalog.rs +++ b/tests/datalog.rs @@ -1,4 +1,4 @@ -use egg::*; +use egg::legacy::*; define_language! { enum Lang { diff --git a/tests/lambda.rs b/tests/lambda.rs index 80ea4fbd..96dab3e8 100644 --- a/tests/lambda.rs +++ b/tests/lambda.rs @@ -1,4 +1,4 @@ -use egg::{rewrite as rw, *}; +use egg::legacy::{rewrite as rw, *}; use fxhash::FxHashSet as HashSet; define_language! { @@ -31,7 +31,7 @@ impl Lambda { } } -type EGraph = egg::EGraph; +type EGraph = egg::legacy::EGraph; #[derive(Default)] struct LambdaAnalysis; @@ -42,8 +42,11 @@ struct Data { constant: Option<(Lambda, PatternAst)>, } -fn eval(egraph: &EGraph, enode: &Lambda) -> Option<(Lambda, PatternAst)> { - let x = |i: &Id| egraph[*i].data.constant.as_ref().map(|c| &c.0); +fn eval>( + egraph: &E, + enode: &Lambda, +) -> Option<(Lambda, PatternAst)> { + let x = |i: &Id| egraph.data(*i).constant.as_ref().map(|c| &c.0); match enode { Lambda::Num(n) => Some((enode.clone(), format!("{}", n).parse().unwrap())), Lambda::Bool(b) => Some((enode.clone(), format!("{}", b).parse().unwrap())), @@ -59,8 +62,11 @@ fn eval(egraph: &EGraph, enode: &Lambda) -> Option<(Lambda, PatternAst)> } } -impl Analysis for LambdaAnalysis { +impl AnalysisData for LambdaAnalysis { type Data = Data; +} + +impl Analysis for LambdaAnalysis { fn merge(&mut self, to: &mut Data, from: Data) -> DidMerge { let before_len = to.free.len(); // to.free.extend(from.free); @@ -75,8 +81,8 @@ impl Analysis for LambdaAnalysis { }) } - fn make(egraph: &mut EGraph, enode: &Lambda) -> Data { - let f = |i: &Id| egraph[*i].data.free.iter().cloned(); + fn make>(egraph: E, enode: &Lambda) -> Data { + let f = |i: &Id| egraph.data(*i).free.iter().cloned(); let mut free = HashSet::default(); match enode { Lambda::Var(v) => { @@ -91,14 +97,14 @@ impl Analysis for LambdaAnalysis { free.extend(f(a)); free.remove(v); } - _ => enode.for_each(|c| free.extend(&egraph[c].data.free)), + _ => enode.for_each(|c| free.extend(&egraph.data(c).free)), } - let constant = eval(egraph, enode); + let constant = eval(&egraph, enode); Data { constant, free } } - fn modify(egraph: &mut EGraph, id: Id) { - if let Some(c) = egraph[id].data.constant.clone() { + fn modify>(mut egraph: E, id: Id) { + if let Some(c) = egraph.data(id).constant.clone() { if egraph.are_explanations_enabled() { egraph.union_instantiations( &c.0.to_string().parse().unwrap(), @@ -123,7 +129,7 @@ fn is_not_same_var(v1: Var, v2: Var) -> impl Fn(&mut EGraph, Id, &Subst) -> bool } fn is_const(v: Var) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { - move |egraph, _, subst| egraph[subst[v]].data.constant.is_some() + move |egraph, _, subst| egraph[subst[v]].data.0.constant.is_some() } fn rules() -> Vec> { @@ -171,7 +177,7 @@ struct CaptureAvoid { if_free: Pattern, } -impl Applier for CaptureAvoid { +impl Applier> for CaptureAvoid { fn apply_one( &self, egraph: &mut EGraph, @@ -182,7 +188,7 @@ impl Applier for CaptureAvoid { ) -> Vec { let e = subst[self.e]; let v2 = subst[self.v2]; - let v2_free_in_e = egraph[e].data.free.contains(&v2); + let v2_free_in_e = egraph[e].data.0.free.contains(&v2); if v2_free_in_e { let mut subst = subst.clone(); let sym = Lambda::Symbol(format!("_{}", eclass).into()); diff --git a/tests/math.rs b/tests/math.rs index a0d8c07a..2010c78f 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -1,8 +1,8 @@ -use egg::{rewrite as rw, *}; +use egg::legacy::{rewrite as rw, *}; use ordered_float::NotNan; -pub type EGraph = egg::EGraph; -pub type Rewrite = egg::Rewrite; +pub type EGraph = egg::legacy::EGraph; +pub type Rewrite = egg::legacy::Rewrite; pub type Constant = NotNan; @@ -47,11 +47,14 @@ impl egg::CostFunction for MathCostFn { #[derive(Default)] pub struct ConstantFold; -impl Analysis for ConstantFold { + +impl AnalysisData for ConstantFold { type Data = Option<(Constant, PatternAst)>; +} - fn make(egraph: &mut EGraph, enode: &Math) -> Self::Data { - let x = |i: &Id| egraph[*i].data.as_ref().map(|d| d.0); +impl Analysis for ConstantFold { + fn make>(egraph: E, enode: &Math) -> Self::Data { + let x = |i: &Id| egraph.data(*i).as_ref().map(|d| d.0); Some(match enode { Math::Constant(c) => (*c, format!("{}", c).parse().unwrap()), Math::Add([a, b]) => ( @@ -66,7 +69,7 @@ impl Analysis for ConstantFold { x(a)? * x(b)?, format!("(* {} {})", x(a)?, x(b)?).parse().unwrap(), ), - Math::Div([a, b]) if x(b) != Some(NotNan::new(0.0).unwrap()) => ( + Math::Div([a, b]) if x(b) != Some(Constant::new(0.0).unwrap()) => ( x(a)? / x(b)?, format!("(/ {} {})", x(a)?, x(b)?).parse().unwrap(), ), @@ -81,8 +84,8 @@ impl Analysis for ConstantFold { }) } - fn modify(egraph: &mut EGraph, id: Id) { - let data = egraph[id].data.clone(); + fn modify>(mut egraph: E, id: Id) { + let data = egraph.data(id).clone(); if let Some((c, pat)) = data { if egraph.are_explanations_enabled() { egraph.union_instantiations( @@ -96,10 +99,10 @@ impl Analysis for ConstantFold { egraph.union(id, added); } // to not prune, comment this out - egraph[id].nodes.retain(|n| n.is_leaf()); - - #[cfg(debug_assertions)] - egraph[id].assert_unique_leaves(); + // egraph[id].nodes.retain(|n| n.is_leaf()); + // + // #[cfg(debug_assertions)] + // egraph[id].assert_unique_leaves(); } } } @@ -109,9 +112,8 @@ fn is_const_or_distinct_var(v: &str, w: &str) -> impl Fn(&mut EGraph, Id, &Subst let w = w.parse().unwrap(); move |egraph, _, subst| { egraph.find(subst[v]) != egraph.find(subst[w]) - && (egraph[subst[v]].data.is_some() + && (egraph[subst[v]].data.0.is_some() || egraph[subst[v]] - .nodes .iter() .any(|n| matches!(n, Math::Symbol(..)))) } @@ -119,14 +121,13 @@ fn is_const_or_distinct_var(v: &str, w: &str) -> impl Fn(&mut EGraph, Id, &Subst fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let var = var.parse().unwrap(); - move |egraph, _, subst| egraph[subst[var]].data.is_some() + move |egraph, _, subst| egraph[subst[var]].data.0.is_some() } fn is_sym(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let var = var.parse().unwrap(); move |egraph, _, subst| { egraph[subst[var]] - .nodes .iter() .any(|n| matches!(n, Math::Symbol(..))) } @@ -135,7 +136,7 @@ fn is_sym(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let var = var.parse().unwrap(); move |egraph, _, subst| { - if let Some(n) = &egraph[subst[var]].data { + if let Some(n) = &egraph[subst[var]].data.0 { *(n.0) != 0.0 } else { true @@ -410,8 +411,8 @@ fn math_ematching_bench() { #[test] fn test_basic_egraph_union_intersect() { - let mut egraph1 = EGraph::new(ConstantFold {}).with_explanations_enabled(); - let mut egraph2 = EGraph::new(ConstantFold {}).with_explanations_enabled(); + let mut egraph1 = EGraph::default().with_explanations_enabled(); + let mut egraph2 = EGraph::default().with_explanations_enabled(); egraph1.union_instantiations( &"x".parse().unwrap(), &"y".parse().unwrap(), @@ -437,7 +438,7 @@ fn test_basic_egraph_union_intersect() { "", ); - let mut egraph3 = egraph1.egraph_intersect(&egraph2, ConstantFold {}); + let mut egraph3 = egraph1.egraph_intersect(&egraph2, Default::default()); egraph2.egraph_union(&egraph1); @@ -479,8 +480,8 @@ fn test_basic_egraph_union_intersect() { #[test] fn test_intersect_basic() { - let mut egraph1 = EGraph::new(ConstantFold {}).with_explanations_enabled(); - let mut egraph2 = EGraph::new(ConstantFold {}).with_explanations_enabled(); + let mut egraph1 = EGraph::default().with_explanations_enabled(); + let mut egraph2 = EGraph::default().with_explanations_enabled(); egraph1.union_instantiations( &"(+ x 0)".parse().unwrap(), &"(+ y 0)".parse().unwrap(), @@ -496,7 +497,7 @@ fn test_intersect_basic() { egraph2.add_expr(&"(+ x 0)".parse().unwrap()); egraph2.add_expr(&"(+ y 0)".parse().unwrap()); - let mut egraph3 = egraph1.egraph_intersect(&egraph2, ConstantFold {}); + let mut egraph3 = egraph1.egraph_intersect(&egraph2, Default::default()); assert_ne!( egraph3.add_expr(&"x".parse().unwrap()), diff --git a/tests/prop.rs b/tests/prop.rs index ed1c7469..faec092f 100644 --- a/tests/prop.rs +++ b/tests/prop.rs @@ -1,4 +1,4 @@ -use egg::*; +use egg::legacy::*; define_language! { enum Prop { @@ -11,13 +11,17 @@ define_language! { } } -type EGraph = egg::EGraph; -type Rewrite = egg::Rewrite; +type EGraph = egg::legacy::EGraph; +type Rewrite = egg::legacy::Rewrite; #[derive(Default)] struct ConstantFold; -impl Analysis for ConstantFold { + +impl AnalysisData for ConstantFold { type Data = Option<(bool, PatternAst)>; +} + +impl Analysis for ConstantFold { fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { merge_option(to, from, |a, b| { assert_eq!(a.0, b.0, "Merged non-equal constants"); @@ -25,8 +29,8 @@ impl Analysis for ConstantFold { }) } - fn make(egraph: &mut EGraph, enode: &Prop) -> Self::Data { - let x = |i: &Id| egraph[*i].data.as_ref().map(|c| c.0); + fn make>(egraph: E, enode: &Prop) -> Self::Data { + let x = |i: &Id| egraph.data(*i).as_ref().map(|c| c.0); let result = match enode { Prop::Bool(c) => Some((*c, c.to_string().parse().unwrap())), Prop::Symbol(_) => None, @@ -48,8 +52,8 @@ impl Analysis for ConstantFold { result } - fn modify(egraph: &mut EGraph, id: Id) { - if let Some(c) = egraph[id].data.clone() { + fn modify>(mut egraph: E, id: Id) { + if let Some(c) = egraph.data(id).clone() { egraph.union_instantiations( &c.1, &c.0.to_string().parse().unwrap(),