Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/eclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pub struct EClass<L, D> {
/// Modifying this field will _not_ cause changes to propagate through the e-graph.
/// Prefer [`EGraph::set_analysis_data`] instead.
pub data: D,
/// The parent enodes and their original Ids.
pub(crate) parents: Vec<(L, Id)>,
/// The original Ids of parent enodes.
pub(crate) parents: Vec<Id>,
}

impl<L, D> EClass<L, D> {
Expand All @@ -37,9 +37,9 @@ impl<L, D> EClass<L, D> {
self.nodes.iter()
}

/// Iterates over the parent enodes of this eclass.
pub fn parents(&self) -> impl ExactSizeIterator<Item = (&L, Id)> {
self.parents.iter().map(|(node, id)| (node, *id))
/// Iterates over the non-canonical ids of parent enodes of this eclass.
pub fn parents(&self) -> impl ExactSizeIterator<Item = Id> + '_ {
self.parents.iter().copied()
}
}

Expand Down
174 changes: 127 additions & 47 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,17 @@ pub struct EGraph<L: Language, N: Analysis<L>> {
/// The `Explain` used to explain equivalences in this `EGraph`.
pub(crate) explain: Option<Explain<L>>,
unionfind: UnionFind,
/// Stores the original node represented by each non-canonical id
nodes: Vec<L>,
/// Stores each enode's `Id`, not the `Id` of the eclass.
/// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new
/// unions can cause them to become out of date.
#[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
memo: HashMap<L, Id>,
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
pending: Vec<(L, Id)>,
analysis_pending: UniqueQueue<(L, Id)>,
pending: Vec<Id>,
analysis_pending: UniqueQueue<Id>,
#[cfg_attr(
feature = "serde-1",
serde(bound(
Expand Down Expand Up @@ -114,6 +116,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
analysis,
classes: Default::default(),
unionfind: Default::default(),
nodes: Default::default(),
clean: false,
explain: None,
pending: Default::default(),
Expand Down Expand Up @@ -214,12 +217,14 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

/// Make a copy of the egraph with the same nodes, but no unions between them.
pub fn copy_without_unions(&self, analysis: N) -> Self {
if let Some(explain) = &self.explain {
let egraph = Self::new(analysis);
explain.populate_enodes(egraph)
} else {
if self.explain.is_none() {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions");
}
let mut egraph = Self::new(analysis);
for node in &self.nodes {
egraph.add(node.clone());
}
egraph
}

/// Performs the union between two egraphs.
Expand Down Expand Up @@ -339,32 +344,70 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical),
/// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical))
pub fn id_to_expr(&self, id: Id) -> RecExpr<L> {
if let Some(explain) = &self.explain {
explain.node_to_recexpr(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id");
let mut res = Default::default();
let mut cache = Default::default();
self.id_to_expr_internal(&mut res, id, &mut cache);
res
}

fn id_to_expr_internal(
&self,
res: &mut RecExpr<L>,
node_id: Id,
cache: &mut HashMap<Id, Id>,
) -> Id {
if let Some(existing) = cache.get(&node_id) {
return *existing;
}
let new_node = self
.id_to_node(node_id)
.clone()
.map_children(|child| self.id_to_expr_internal(res, child, cache));
let res_id = res.add(new_node);
cache.insert(node_id, res_id);
res_id
}

/// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep
pub fn id_to_node(&self, id: Id) -> &L {
if let Some(explain) = &self.explain {
explain.node(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id");
}
&self.nodes[usize::from(id)]
}

/// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term.
/// When an eclass listed in the given substitutions is found, it creates a variable.
/// It also adds this variable and the corresponding Id value to the resulting [`Subst`]
/// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr).
pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap<Id, Id>) -> (Pattern<L>, Subst) {
if let Some(explain) = &self.explain {
explain.node_to_pattern(id, substitutions)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id");
let mut res = Default::default();
let mut subst = Default::default();
let mut cache = Default::default();
self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache);
(Pattern::new(res), subst)
}

fn id_to_pattern_internal(
&self,
res: &mut PatternAst<L>,
node_id: Id,
var_substitutions: &HashMap<Id, Id>,
subst: &mut Subst,
cache: &mut HashMap<Id, Id>,
) -> Id {
if let Some(existing) = cache.get(&node_id) {
return *existing;
}
let res_id = if let Some(existing) = var_substitutions.get(&node_id) {
let var = format!("?{}", node_id).parse().unwrap();
subst.insert(var, *existing);
res.add(ENodeOrVar::Var(var))
} else {
let new_node = self.id_to_node(node_id).clone().map_children(|child| {
self.id_to_pattern_internal(res, child, var_substitutions, subst, cache)
});
res.add(ENodeOrVar::ENode(new_node))
};
cache.insert(node_id, res_id);
res_id
}

/// Get all the unions ever found in the egraph in terms of enode ids.
Expand All @@ -390,17 +433,19 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Get the number of congruences between nodes in the egraph.
/// Only available when explanations are enabled.
pub fn get_num_congr(&mut self) -> usize {
if let Some(explain) = &self.explain {
explain.get_num_congr::<N>(&self.classes, &self.unionfind)
if let Some(explain) = &mut self.explain {
explain
.with_nodes(&self.nodes)
.get_num_congr::<N>(&self.classes, &self.unionfind)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
}

/// Get the number of nodes in the egraph used for explanations.
pub fn get_explanation_num_nodes(&mut self) -> usize {
if let Some(explain) = &self.explain {
explain.get_num_nodes()
if let Some(explain) = &mut self.explain {
explain.with_nodes(&self.nodes).get_num_nodes()
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand Down Expand Up @@ -438,7 +483,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
);
}
if let Some(explain) = &mut self.explain {
explain.explain_equivalence::<N>(left, right, &mut self.unionfind, &self.classes)
explain.with_nodes(&self.nodes).explain_equivalence::<N>(
left,
right,
&mut self.unionfind,
&self.classes,
)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -461,7 +511,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// but more efficient
fn explain_existance_id(&mut self, id: Id) -> Explanation<L> {
if let Some(explain) = &mut self.explain {
explain.explain_existance(id)
explain.with_nodes(&self.nodes).explain_existance(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -475,7 +525,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
) -> Explanation<L> {
let id = self.add_instantiation_noncanonical(pattern, subst);
if let Some(explain) = &mut self.explain {
explain.explain_existance(id)
explain.with_nodes(&self.nodes).explain_existance(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -498,7 +548,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
);
}
if let Some(explain) = &mut self.explain {
explain.explain_equivalence::<N>(left, right, &mut self.unionfind, &self.classes)
explain.with_nodes(&self.nodes).explain_equivalence::<N>(
left,
right,
&mut self.unionfind,
&self.classes,
)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.");
}
Expand Down Expand Up @@ -586,7 +641,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

/// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical
///
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr`
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled
pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
let nodes = expr.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
Expand Down Expand Up @@ -624,7 +679,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// canonical
///
/// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an corrispond to the
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the
/// instantiation of the pattern
fn add_instantiation_noncanonical(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
let nodes = pat.as_ref();
Expand Down Expand Up @@ -744,7 +799,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will
/// correspond to the parameter `enode`
///
/// # Example
/// ## Example
/// ```
/// # use egg::*;
/// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_enabled();
Expand All @@ -759,6 +814,25 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
/// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap());
/// ```
///
/// When explanations are not enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will
/// produce an expression with equivalent but not necessarily identical children
///
/// # Example
/// ```
/// # use egg::*;
/// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_disabled();
/// let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
/// let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
/// egraph.union(a, b);
/// egraph.rebuild();
///
/// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
/// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
///
/// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
/// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap());
/// ```
pub fn add_uncanonical(&mut self, mut enode: L) -> Id {
let original = enode.clone();
if let Some(existing_id) = self.lookup_internal(&mut enode) {
Expand All @@ -769,7 +843,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
*existing_explain
} else {
let new_id = self.unionfind.make_set();
explain.add(original, new_id, new_id);
explain.add(original.clone(), new_id, new_id);
debug_assert_eq!(Id::from(self.nodes.len()), new_id);
self.nodes.push(original);
self.unionfind.union(id, new_id);
explain.union(existing_id, new_id, Justification::Congruence, true);
new_id
Expand All @@ -778,7 +854,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
existing_id
}
} else {
let id = self.make_new_eclass(enode);
let id = self.make_new_eclass(enode, original.clone());
if let Some(explain) = self.explain.as_mut() {
explain.add(original, id, id);
}
Expand All @@ -791,24 +867,26 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}

/// This function makes a new eclass in the egraph (but doesn't touch explanations)
fn make_new_eclass(&mut self, enode: L) -> Id {
fn make_new_eclass(&mut self, enode: L, original: L) -> Id {
let id = self.unionfind.make_set();
log::trace!(" ...adding to {}", id);
let class = EClass {
id,
nodes: vec![enode.clone()],
data: N::make(self, &enode),
data: N::make(self, &original),
parents: Default::default(),
};

debug_assert_eq!(Id::from(self.nodes.len()), id);
self.nodes.push(original);

// add this enode to the parent lists of its children
enode.for_each(|child| {
let tup = (enode.clone(), id);
self[child].parents.push(tup);
self[child].parents.push(id);
});

// TODO is this needed?
self.pending.push((enode.clone(), id));
self.pending.push(id);

self.classes.insert(id, class);
assert!(self.memo.insert(enode, id).is_none());
Expand Down Expand Up @@ -943,13 +1021,13 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let class1 = self.classes.get_mut(&id1).unwrap();
assert_eq!(id1, class1.id);

self.pending.extend(class2.parents.iter().cloned());
self.pending.extend(class2.parents.iter().copied());
let did_merge = self.analysis.merge(&mut class1.data, class2.data);
if did_merge.0 {
self.analysis_pending.extend(class1.parents.iter().cloned());
self.analysis_pending.extend(class1.parents.iter().copied());
}
if did_merge.1 {
self.analysis_pending.extend(class2.parents.iter().cloned());
self.analysis_pending.extend(class2.parents.iter().copied());
}

concat_vecs(&mut class1.nodes, class2.nodes);
Expand All @@ -968,7 +1046,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let id = self.find_mut(id);
let class = self.classes.get_mut(&id).unwrap();
class.data = new_data;
self.analysis_pending.extend(class.parents.iter().cloned());
self.analysis_pending.extend(class.parents.iter().copied());
N::modify(self, id)
}

Expand Down Expand Up @@ -1103,7 +1181,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let mut n_unions = 0;

while !self.pending.is_empty() || !self.analysis_pending.is_empty() {
while let Some((mut node, class)) = self.pending.pop() {
while let Some(class) = self.pending.pop() {
let mut node = self.nodes[usize::from(class)].clone();
node.update_children(|id| self.find_mut(id));
if let Some(memo_class) = self.memo.insert(node, class) {
let did_something = self.perform_union(
Expand All @@ -1116,14 +1195,15 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}
}

while let Some((node, class_id)) = self.analysis_pending.pop() {
while let Some(class_id) = self.analysis_pending.pop() {
let node = self.nodes[usize::from(class_id)].clone();
let class_id = self.find_mut(class_id);
let node_data = N::make(self, &node);
let class = self.classes.get_mut(&class_id).unwrap();

let did_merge = self.analysis.merge(&mut class.data, node_data);
if did_merge.0 {
self.analysis_pending.extend(class.parents.iter().cloned());
self.analysis_pending.extend(class.parents.iter().copied());
N::modify(self, class_id)
}
}
Expand Down Expand Up @@ -1204,9 +1284,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
n_unions
}

pub(crate) fn check_each_explain(&self, rules: &[&Rewrite<L, N>]) -> bool {
if let Some(explain) = &self.explain {
explain.check_each_explain(rules)
pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite<L, N>]) -> bool {
if let Some(explain) = &mut self.explain {
explain.with_nodes(&self.nodes).check_each_explain(rules)
} else {
panic!("Can't check explain when explanations are off");
}
Expand Down
Loading