diff --git a/hugr-core/src/core.rs b/hugr-core/src/core.rs index 67a03eb494..b027b9c731 100644 --- a/hugr-core/src/core.rs +++ b/hugr-core/src/core.rs @@ -277,7 +277,8 @@ impl std::fmt::Display for Wire { } /// Marks [FuncDefn](crate::ops::FuncDefn)s and [FuncDecl](crate::ops::FuncDecl)s as -/// to whether they should be considered for linking. +/// to whether they should be considered for linking, and as reachable (starting points) +/// for optimization/analysis. #[derive( Clone, Debug, diff --git a/hugr-passes/src/call_graph.rs b/hugr-passes/src/call_graph.rs index 26df84e92e..eed3033411 100644 --- a/hugr-passes/src/call_graph.rs +++ b/hugr-passes/src/call_graph.rs @@ -18,7 +18,7 @@ pub enum CallGraphNode { FuncDecl(N), /// petgraph-node corresponds to a [`FuncDefn`](OpType::FuncDefn) node (specified) in the Hugr FuncDefn(N), - /// petgraph-node corresponds to the root node of the hugr, that is not + /// petgraph-node corresponds to the entrypoint node of the hugr, that is not /// a [`FuncDefn`](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module) /// either, as such a node could not have outgoing edges, so is not represented in the petgraph. NonFuncRoot, diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 9c450b0aca..bf0e3a6718 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -15,20 +15,27 @@ use hugr_core::{ }; use value_handle::ValueHandle; -use crate::dataflow::{ - ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination, - partial_from_const, -}; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; use crate::{ComposablePass, composable::validate_if_test}; +use crate::{ + VisPolicy, + dataflow::{ + ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination, + partial_from_const, + }, +}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. +/// +/// Note that by default we assume that only the entrypoint is reachable and +/// only if it is not the module root; see [Self::with_inputs]. Mutation +/// occurs anywhere beneath the entrypoint. pub struct ConstantFoldPass { allow_increase_termination: bool, /// Each outer key Node must be either: - /// - a `FuncDefn` child of the root, if the root is a module; or - /// - the root, if the root is not a Module + /// - a `FuncDefn` child of the module-root + /// - the entrypoint inputs: HashMap>, } @@ -36,9 +43,8 @@ pub struct ConstantFoldPass { #[non_exhaustive] /// Errors produced by [`ConstantFoldPass`]. pub enum ConstFoldError { - /// Error raised when a Node is specified as an entry-point but - /// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor - /// a [Conditional](OpType::Conditional). + /// Error raised when inputs are provided for a Node that is neither a dataflow + /// parent, nor a [CFG](OpType::CFG), nor a [Conditional](OpType::Conditional). #[error("{node} has OpType {op} which cannot be an entry-point")] InvalidEntryPoint { /// The node which was specified as an entry-point @@ -46,7 +52,7 @@ pub enum ConstFoldError { /// The `OpType` of the node op: OpType, }, - /// The chosen entrypoint is not in the hugr. + /// Inputs were provided for a node that is not in the hugr. #[error("Entry-point {node} is not part of the Hugr")] MissingEntryPoint { /// The missing node @@ -67,15 +73,25 @@ impl ConstantFoldPass { } /// Specifies a number of external inputs to an entry point of the Hugr. - /// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` child of the root; - /// or for non-Module-rooted Hugrs, `node` is the root of the Hugr. (This is not + /// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` (child of the root); + /// for non-Module-rooted Hugrs, `node` is the [HugrView::entrypoint]. (This is not /// enforced, but it must be a container and not a module itself.) /// /// Multiple calls for the same entry-point combine their values, with later /// values on the same in-port replacing earlier ones. /// - /// Note that if `inputs` is empty, this still marks the node as an entry-point, i.e. - /// we must preserve nodes required to compute its result. + /// Note that providing empty `inputs` indicates that we must preserve the ability + /// to compute the result of `node` for all possible inputs. + /// * If the entrypoint is the module-root, this method should be called for every + /// [FuncDefn] that is externally callable + /// * Otherwise, i.e. if the entrypoint is not the module-root, + /// * The default is to assume the entrypoint is callable with any inputs; + /// * If `node` is the entrypoint, this method allows to restrict the possible inputs + /// * If `node` is beneath the entrypoint, this merely degrades the analysis. (We + /// will mutate only beneath the entrypoint, but using results of analysing the + /// whole Hugr wrt. the specified/any inputs too). + /// + /// [FuncDefn]: hugr_core::ops::FuncDefn pub fn with_inputs( mut self, node: Node, @@ -97,8 +113,7 @@ impl + 'static> ComposablePass for ConstantFoldPass { /// /// # Errors /// - /// [`ConstFoldError::InvalidEntryPoint`] if an entry-point added by [`Self::with_inputs`] - /// was of an invalid [`OpType`] + /// [ConstFoldError] if inputs were provided via [`Self::with_inputs`] for an invalid node. fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), @@ -184,25 +199,51 @@ impl + 'static> ComposablePass for ConstantFoldPass { } } +const NO_INPUTS: [(IncomingPort, Value); 0] = []; + /// Exhaustively apply constant folding to a HUGR. /// If the Hugr's entrypoint is its [`Module`], assumes all [`FuncDefn`] children are reachable. +/// Otherwise, assume that the [HugrView::entrypoint] is itself reachable. /// /// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn /// [`Module`]: hugr_core::ops::OpType::Module +#[deprecated(note = "Use fold_constants, or manually configure ConstantFoldPass")] pub fn constant_fold_pass + 'static>(mut h: impl AsMut) { let h = h.as_mut(); let c = ConstantFoldPass::default(); let c = if h.get_optype(h.entrypoint()).is_module() { - let no_inputs: [(IncomingPort, _); 0] = []; h.children(h.entrypoint()) .filter(|n| h.get_optype(*n).is_func_defn()) - .fold(c, |c, n| c.with_inputs(n, no_inputs.iter().cloned())) + .fold(c, |c, n| c.with_inputs(n, NO_INPUTS.clone())) } else { c }; validate_if_test(c, h).unwrap(); } +/// Exhaustively apply constant folding to a HUGR. +/// Assumes that the Hugr's entrypoint is reachable (if it is not a [`Module`]). +/// Also uses `policy` to determine which public [`FuncDefn`] children of the [`HugrView::module_root`] are reachable. +/// +/// [`Module`]: hugr_core::ops::OpType::Module +/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn +pub fn fold_constants(h: &mut (impl HugrMut + 'static), policy: VisPolicy) { + let mut funcs = Vec::new(); + if !h.entrypoint_optype().is_module() { + funcs.push(h.entrypoint()); + } + if policy.for_hugr(&h) { + funcs.extend( + h.children(h.module_root()) + .filter(|n| h.get_optype(*n).is_func_defn()), + ) + } + let c = funcs.into_iter().fold(ConstantFoldPass::default(), |c, n| { + c.with_inputs(n, NO_INPUTS.clone()) + }); + validate_if_test(c, h).unwrap(); +} + struct ConstFoldContext; impl ConstLoader> for ConstFoldContext { diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index a60684ec07..a74e9ab96d 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -29,10 +29,14 @@ use hugr_core::std_extensions::logic::LogicOp; use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, type_row}; -use crate::ComposablePass as _; use crate::dataflow::{DFContext, PartialValue, partial_from_const}; +use crate::{ComposablePass as _, VisPolicy}; -use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, constant_fold_pass}; +use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, fold_constants}; + +fn constant_fold_pass(h: &mut (impl HugrMut + 'static)) { + fold_constants(h, VisPolicy::AllPublic); +} #[rstest] #[case(ConstInt::new_u(4, 2).unwrap(), true)] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0368f931bc..898794cdf3 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -121,7 +121,7 @@ impl Machine { } else { let ep = self.0.entrypoint(); let mut p = in_values.into_iter().peekable(); - // We must provide some inputs to the root so that they are Top rather than Bottom. + // We must provide some inputs to the entrypoint so that they are Top rather than Bottom. // (However, this test will fail for DataflowBlock or Case roots, i.e. if no // inputs have been provided they will still see Bottom. We could store the "input" // values for even these nodes in self.1 and then convert to actual Wire values diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index 9be3eaa856..f788d9bacf 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -1,7 +1,7 @@ //! Pass for removing dead code, i.e. that computes values that are then discarded use hugr_core::hugr::internal::HugrInternals; -use hugr_core::{HugrView, hugr::hugrmut::HugrMut, ops::OpType}; +use hugr_core::{HugrView, Visibility, hugr::hugrmut::HugrMut, ops::OpType}; use std::convert::Infallible; use std::fmt::{Debug, Formatter}; use std::{ @@ -9,17 +9,19 @@ use std::{ sync::Arc, }; -use crate::ComposablePass; +use crate::{ComposablePass, VisPolicy}; -/// Configuration for Dead Code Elimination pass +/// Configuration for Dead Code Elimination pass, i.e. which removes nodes +/// beneath the [HugrView::entrypoint] that compute only unneeded values. #[derive(Clone)] pub struct DeadCodeElimPass { /// Nodes that are definitely needed - e.g. `FuncDefns`, but could be anything. - /// Hugr Root is assumed to be an entry point even if not mentioned here. + /// [HugrView::entrypoint] is assumed to be needed even if not mentioned here. entry_points: Vec, /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [`PreserveNode::default_for`]. preserve_callback: Arc>, + include_exports: VisPolicy, } impl Default for DeadCodeElimPass { @@ -27,6 +29,7 @@ impl Default for DeadCodeElimPass { Self { entry_points: Default::default(), preserve_callback: Arc::new(PreserveNode::default_for), + include_exports: VisPolicy::default(), } } } @@ -39,11 +42,13 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a, N> { entry_points: &'a Vec, + include_exports: VisPolicy, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, + include_exports: self.include_exports, }, f, ) @@ -69,12 +74,12 @@ pub enum PreserveNode { impl PreserveNode { /// A conservative default for a given node. Just examines the node's [`OpType`]: - /// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn`, but would - /// also need to check for cycles in the [`CallGraph`](super::call_graph::CallGraph).) + /// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn` for + /// termination, but would also need to check for cycles in the `CallGraph`.) /// * Assumes all CFGs must be preserved. (One could, for example, allow acyclic /// CFGs to be removed.) - /// * Assumes all `TailLoops` must be preserved. (One could, for example, use dataflow - /// analysis to allow removal of `TailLoops` that never [Continue](hugr_core::ops::TailLoop::CONTINUE_TAG).) + /// * Assumes all `TailLoops` must be preserved. (One could use some analysis, e.g. + /// dataflow, to allow removal of `TailLoops` with a bounded number of iterations.) pub fn default_for(h: &H, n: H::Node) -> PreserveNode { match h.get_optype(n) { OpType::CFG(_) | OpType::TailLoop(_) | OpType::Call(_) => PreserveNode::MustKeep, @@ -91,16 +96,33 @@ impl DeadCodeElimPass { self } - /// Mark some nodes as entry points to the Hugr, i.e. so we cannot eliminate any code - /// used to evaluate these nodes. - /// [`HugrView::entrypoint`] is assumed to be an entry point; - /// for Module roots the client will want to mark some of the `FuncDefn` children - /// as entry points too. + /// Mark some nodes as reachable, i.e. so we cannot eliminate any code used to + /// evaluate their results. The [`HugrView::entrypoint`] is assumed to be reachable; + /// if that is the [`HugrView::module_root`], then any public [FuncDefn] and + /// [FuncDecl]s are also considered reachable by default, + /// but this can be change by [`Self::include_module_exports`]. + /// + /// [FuncDecl]: OpType::FuncDecl + /// [FuncDefn]: OpType::FuncDefn pub fn with_entry_points(mut self, entry_points: impl IntoIterator) -> Self { self.entry_points.extend(entry_points); self } + /// Sets whether the exported [FuncDefn](OpType::FuncDefn)s and + /// [FuncDecl](OpType::FuncDecl)s are considered reachable. + /// + /// Note that for non-module-entry Hugrs this has no effect, since we only remove + /// code beneath the entrypoint: this cannot be affected by other module children. + /// + /// So, for module-rooted-Hugrs: [VisPolicy::PublicIfModuleEntrypoint] is + /// equivalent to [VisPolicy::AllPublic]; and [VisPolicy::None] will remove + /// all children, unless some are explicity added by [Self::with_entry_points]. + pub fn include_module_exports(mut self, include: VisPolicy) -> Self { + self.include_exports = include; + self + } + fn find_needed_nodes(&self, h: &H) -> HashSet { let mut must_preserve = HashMap::new(); let mut needed = HashSet::new(); @@ -111,19 +133,23 @@ impl DeadCodeElimPass { continue; } for ch in h.children(n) { - if self.must_preserve(h, &mut must_preserve, ch) - || matches!( - h.get_optype(ch), + let must_keep = match h.get_optype(ch) { OpType::Case(_) // Include all Cases in Conditionals | OpType::DataflowBlock(_) // and all Basic Blocks in CFGs | OpType::ExitBlock(_) | OpType::AliasDecl(_) // and all Aliases (we do not track their uses in types) | OpType::AliasDefn(_) | OpType::Input(_) // Also Dataflow input/output, these are necessary for legality - | OpType::Output(_) // Do not include FuncDecl / FuncDefn / Const unless reachable by static edges - // (from Call/LoadConst/LoadFunction): - ) - { + | OpType::Output(_) => true, + // FuncDefns (as children of Module) only if public and including exports + // (will be included if static predecessors of Call/LoadFunction below, + // regardless of Visibility or self.include_exports) + OpType::FuncDefn(fd) => fd.visibility() == &Visibility::Public && self.include_exports.for_hugr(h), + OpType::FuncDecl(fd) => fd.visibility() == &Visibility::Public && self.include_exports.for_hugr(h), + // No Const, unless reached along static edges + _ => false + }; + if must_keep || self.must_preserve(h, &mut must_preserve, ch) { q.push_back(ch); } } @@ -141,7 +167,6 @@ impl DeadCodeElimPass { if let Some(res) = cache.get(&n) { return *res; } - #[allow(deprecated)] let res = match self.preserve_callback.as_ref()(h, n) { PreserveNode::MustKeep => true, PreserveNode::CanRemoveIgnoringChildren => false, @@ -174,18 +199,57 @@ impl ComposablePass for DeadCodeElimPass { mod test { use std::sync::Arc; - use hugr_core::Hugr; - use hugr_core::builder::{CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder}; + use hugr_core::builder::{ + CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, + }; use hugr_core::extension::prelude::{ConstUsize, usize_t}; - use hugr_core::ops::{OpTag, OpTrait, handle::NodeHandle}; - use hugr_core::types::Signature; - use hugr_core::{HugrView, ops::Value, type_row}; + use hugr_core::ops::{OpTag, OpTrait, Value, handle::NodeHandle}; + use hugr_core::{Hugr, HugrView, type_row, types::Signature}; use itertools::Itertools; + use rstest::rstest; - use crate::ComposablePass; + use crate::{ComposablePass, VisPolicy}; use super::{DeadCodeElimPass, PreserveNode}; + #[rstest] + #[case(false, VisPolicy::None, true)] + #[case(false, VisPolicy::PublicIfModuleEntrypoint, false)] + #[case(false, VisPolicy::AllPublic, false)] + #[case(true, VisPolicy::None, true)] + #[case(true, VisPolicy::PublicIfModuleEntrypoint, false)] + #[case(true, VisPolicy::AllPublic, false)] + fn test_module_exports( + #[case] include_dfn: bool, + #[case] module_exports: VisPolicy, + #[case] decl_removed: bool, + ) { + let mut mb = ModuleBuilder::new(); + let dfn = mb + .define_function("foo", Signature::new_endo(usize_t())) + .unwrap(); + let ins = dfn.input_wires(); + let dfn = dfn.finish_with_outputs(ins).unwrap(); + let dcl = mb + .declare("bar", Signature::new_endo(usize_t()).into()) + .unwrap(); + let mut h = mb.finish_hugr().unwrap(); + let mut dce = DeadCodeElimPass::::default().include_module_exports(module_exports); + if include_dfn { + dce = dce.with_entry_points([dfn.node()]); + } + dce.run(&mut h).unwrap(); + let defn_retained = include_dfn; + let decl_retained = !decl_removed; + let children = h.children(h.module_root()).collect_vec(); + assert_eq!(defn_retained, children.iter().contains(&dfn.node())); + assert_eq!(decl_retained, children.iter().contains(&dcl.node())); + assert_eq!( + children.len(), + (defn_retained as usize) + (decl_retained as usize) + ); + } + #[test] fn test_cfg_callback() { let mut cb = CFGBuilder::new(Signature::new_endo(type_row![])).unwrap(); diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index 69ae288623..37f7ab4197 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -3,14 +3,14 @@ use std::collections::HashSet; use hugr_core::{ - HugrView, Node, + HugrView, Node, Visibility, hugr::hugrmut::HugrMut, ops::{OpTag, OpTrait}, }; use petgraph::visit::{Dfs, Walker}; use crate::{ - ComposablePass, + ComposablePass, VisPolicy, composable::{ValidatePassError, validate_if_test}, }; @@ -51,6 +51,7 @@ fn reachable_funcs<'a, H: HugrView>( /// A configuration for the Dead Function Removal pass. pub struct RemoveDeadFuncsPass { entry_points: Vec, + include_exports: VisPolicy, } impl RemoveDeadFuncsPass { @@ -66,6 +67,13 @@ impl RemoveDeadFuncsPass { self.entry_points.extend(entry_points); self } + + /// Sets whether the exported [FuncDefn](hugr_core::ops::FuncDefn) children are + /// included as entry points for reachability analysis - see [VisPolicy]. + pub fn include_module_exports(mut self, include: VisPolicy) -> Self { + self.include_exports = include; + self + } } impl> ComposablePass for RemoveDeadFuncsPass { @@ -73,6 +81,14 @@ impl> ComposablePass for RemoveDeadFuncsPass { type Result = (); fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { let mut entry_points = Vec::new(); + if self.include_exports.for_hugr(hugr) { + entry_points.extend(hugr.children(hugr.module_root()).filter(|ch| { + hugr.get_optype(*ch) + .as_func_defn() + .is_some_and(|fd| fd.visibility() == &Visibility::Public) + })); + } + for &n in self.entry_points.iter() { if !hugr.get_optype(n).is_func_defn() { return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n }); @@ -116,6 +132,9 @@ impl> ComposablePass for RemoveDeadFuncsPass { /// Note that for a [`Module`]-rooted Hugr with no `entry_points` provided, this will remove /// all functions from the module. /// +/// Note that, unlike [`DeadCodeElimPass`], this can remove functions *outside* the +/// [HugrView::entrypoint]. +/// /// # Errors /// * If any node in `entry_points` is not a [`FuncDefn`] /// @@ -123,61 +142,96 @@ impl> ComposablePass for RemoveDeadFuncsPass { /// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn /// [`LoadFunction`]: hugr_core::ops::OpType::LoadFunction /// [`Module`]: hugr_core::ops::OpType::Module +/// [`DeadCodeElimPass`]: super::DeadCodeElimPass +#[deprecated( // TODO When removing, rename remove_dead_funcs2 over this + note = "Does not account for visibility; use remove_dead_funcs2 or manually configure RemoveDeadFuncsPass" +)] pub fn remove_dead_funcs( h: &mut impl HugrMut, entry_points: impl IntoIterator, ) -> Result<(), ValidatePassError> { validate_if_test( - RemoveDeadFuncsPass::default().with_module_entry_points(entry_points), + RemoveDeadFuncsPass::default() + .include_module_exports(VisPolicy::None) + .with_module_entry_points(entry_points), h, ) } +/// Deletes from the Hugr any functions that are not used by either [`Call`] or +/// [`LoadFunction`] nodes in parts reachable from the entrypoint or public +/// [`FuncDefn`] children thereof. That is, +/// +/// * If the [HugrView::entrypoint] is the module root, then any [`FuncDefn`] children +/// with [Visibility::Public] will be considered reachable; +/// * otherwise, the [HugrView::entrypoint] itself will. +/// +/// Note that, unlike [`DeadCodeElimPass`], this can remove functions *outside* the +/// [HugrView::entrypoint]. +/// +/// [`Call`]: hugr_core::ops::OpType::Call +/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn +/// [`LoadFunction`]: hugr_core::ops::OpType::LoadFunction +/// [`Module`]: hugr_core::ops::OpType::Module +/// [`DeadCodeElimPass`]: super::DeadCodeElimPass +pub fn remove_dead_funcs2( + h: &mut impl HugrMut, +) -> Result<(), ValidatePassError> { + validate_if_test(RemoveDeadFuncsPass::default(), h) +} + #[cfg(test)] mod test { use std::collections::HashMap; - use hugr_core::ops::handle::NodeHandle; use itertools::Itertools; use rstest::rstest; use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; use hugr_core::hugr::hugrmut::HugrMut; - use hugr_core::{HugrView, extension::prelude::usize_t, types::Signature}; + use hugr_core::ops::handle::NodeHandle; + use hugr_core::{HugrView, Visibility, extension::prelude::usize_t, types::Signature}; - use super::remove_dead_funcs; + use super::RemoveDeadFuncsPass; + use crate::{ComposablePass, VisPolicy}; #[rstest] - #[case(false, [], vec![])] // No entry_points removes everything! - #[case(true, [], vec!["from_main", "main"])] - #[case(false, ["main"], vec!["from_main", "main"])] - #[case(false, ["from_main"], vec!["from_main"])] - #[case(false, ["other1"], vec!["other1", "other2"])] - #[case(true, ["other2"], vec!["from_main", "main", "other2"])] - #[case(false, ["other1", "other2"], vec!["other1", "other2"])] + #[case(false, VisPolicy::default(), [], vec!["from_pub", "pubfunc"])] + #[case(false, VisPolicy::None, ["ment"], vec!["from_ment", "ment"])] + #[case(false, VisPolicy::None, ["from_ment", "from_pub"], vec!["from_ment", "from_pub"])] + #[case(false, VisPolicy::default(), ["from_ment"], vec!["from_ment", "from_pub", "pubfunc"])] + #[case(false, VisPolicy::AllPublic, ["ment"], vec!["from_ment", "from_pub", "ment", "pubfunc"])] + #[case(true, VisPolicy::default(), [], vec!["from_ment", "ment"])] + #[case(true, VisPolicy::AllPublic, [], vec!["from_ment", "from_pub", "ment", "pubfunc"])] + #[case(true, VisPolicy::None, ["from_pub"], vec!["from_ment", "from_pub", "ment"])] fn remove_dead_funcs_entry_points( #[case] use_hugr_entrypoint: bool, + #[case] inc: VisPolicy, #[case] entry_points: impl IntoIterator, #[case] retained_funcs: Vec<&'static str>, ) -> Result<(), Box> { let mut hb = ModuleBuilder::new(); - let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?; + let o2 = hb.define_function("from_pub", Signature::new_endo(usize_t()))?; let o2inp = o2.input_wires(); let o2 = o2.finish_with_outputs(o2inp)?; - let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?; + let mut o1 = hb.define_function_vis( + "pubfunc", + Signature::new_endo(usize_t()), + Visibility::Public, + )?; let o1c = o1.call(o2.handle(), &[], o1.input_wires())?; o1.finish_with_outputs(o1c.outputs())?; - let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?; + let fm = hb.define_function("from_ment", Signature::new_endo(usize_t()))?; let f_inp = fm.input_wires(); let fm = fm.finish_with_outputs(f_inp)?; - let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?; - let m_in = m.input_wires(); - let mut dfg = m.dfg_builder(Signature::new_endo(usize_t()), m_in)?; - let c = dfg.call(fm.handle(), &[], dfg.input_wires())?; - let dfg = dfg.finish_with_outputs(c.outputs()).unwrap(); - m.finish_with_outputs(dfg.outputs())?; + + let mut me = hb.define_function("ment", Signature::new_endo(usize_t()))?; + let mut dfg = me.dfg_builder(Signature::new_endo(usize_t()), me.input_wires())?; + let mc = dfg.call(fm.handle(), &[], dfg.input_wires())?; + let dfg = dfg.finish_with_outputs(mc.outputs()).unwrap(); + me.finish_with_outputs(dfg.outputs())?; let mut hugr = hb.finish_hugr()?; if use_hugr_entrypoint { @@ -193,14 +247,16 @@ mod test { }) .collect::>(); - remove_dead_funcs( - &mut hugr, - entry_points - .into_iter() - .map(|name| *avail_funcs.get(name).unwrap()) - .collect::>(), - ) - .unwrap(); + RemoveDeadFuncsPass::default() + .include_module_exports(inc) + .with_module_entry_points( + entry_points + .into_iter() + .map(|name| *avail_funcs.get(name).unwrap()) + .collect::>(), + ) + .run(&mut hugr) + .unwrap(); let remaining_funcs = hugr .nodes() diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 70b887a40c..7d4e3faf02 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -8,10 +8,16 @@ pub mod dataflow; pub mod dead_code; pub use dead_code::DeadCodeElimPass; mod dead_funcs; -pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass, remove_dead_funcs}; +#[deprecated( + note = "Does not account for visibility; use remove_dead_funcs2 or manually configure RemoveDeadFuncsPass" +)] +#[allow(deprecated)] // When original removed, rename remove_dead_funcs2=>remove_dead_funcs +pub use dead_funcs::remove_dead_funcs; +pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass, remove_dead_funcs2}; pub mod force_order; mod half_node; pub mod linearize_array; +use hugr_core::HugrView; pub use linearize_array::LinearizeArrayPass; pub mod lower; pub mod merge_bbs; @@ -27,3 +33,37 @@ pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; pub use untuple::UntuplePass; + +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +/// A policy for selecting [FuncDefn] and [FuncDecl]s using their [Visibility], +/// e.g. (typically) to use as starting points for analysis +/// +/// [FuncDefn]: hugr_core::ops::FuncDefn +/// [FuncDecl]: hugr_core::ops::FuncDecl +/// [Visibility]: hugr_core::Visibility +pub enum VisPolicy { + /// All [Public] functions should be used + /// + /// [Public]: hugr_core::Visibility::Public + AllPublic, + /// Do not select any functions + None, + /// Use the [Public] functions if the Hugr's [entrypoint] is the [module_root], + /// otherwise do not use any. + /// + /// [Public]: hugr_core::Visibility::Public + /// [entrypoint]: hugr_core::HugrView::entrypoint + /// [module_root]: hugr_core::HugrView::module_root + #[default] + PublicIfModuleEntrypoint, +} + +impl VisPolicy { + /// Returns whether to include the public functions of a particular Hugr + fn for_hugr(&self, h: &impl HugrView) -> bool { + matches!( + (self, h.entrypoint() == h.module_root()), + (Self::AllPublic, _) | (Self::PublicIfModuleEntrypoint, true) + ) + } +} diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 2d5abd5eb1..d20117a527 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -140,11 +140,12 @@ fn instantiate( Entry::Vacant(ve) => ve, }; - let name = mangle_name( - h.get_optype(poly_func).as_func_defn().unwrap().func_name(), - &type_args, + let defn = h.get_optype(poly_func).as_func_defn().unwrap(); + let name = mangle_name(defn.func_name(), &type_args); + let mono_tgt = h.add_node_after( + poly_func, + FuncDefn::new_vis(name, mono_sig, defn.visibility().clone()), ); - let mono_tgt = h.add_node_after(poly_func, FuncDefn::new(name, mono_sig)); // Insert BEFORE we scan (in case of recursion), hence we cannot use Entry::or_insert ve.insert(mono_tgt); // Now make the instantiation @@ -281,13 +282,13 @@ mod test { HugrBuilder, ModuleBuilder, }; use hugr_core::extension::prelude::{ConstUsize, UnpackTuple, UnwrapBuilder, usize_t}; - use hugr_core::ops::handle::{FuncID, NodeHandle}; + use hugr_core::ops::handle::FuncID; use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag}; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum}; - use hugr_core::{Hugr, HugrView, Node}; + use hugr_core::{Hugr, HugrView, Node, Visibility}; use rstest::rstest; - use crate::{monomorphize, remove_dead_funcs}; + use crate::{monomorphize, remove_dead_funcs2}; use super::{is_polymorphic, mangle_name}; @@ -349,9 +350,13 @@ mod test { let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem])?; fb.finish_with_outputs(trip.outputs())? }; - let mn = { + { let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))]; - let mut fb = mb.define_function("main", Signature::new(usize_t(), outs))?; + let mut fb = mb.define_function_vis( + "main", + Signature::new(usize_t(), outs), + Visibility::Public, + )?; let [elem] = fb.input_wires_arr(); let [res1] = fb .call(tr.handle(), &[usize_t().into()], [elem])? @@ -394,12 +399,12 @@ mod test { assert_eq!(mono2, mono); // Idempotent let mut nopoly = mono; - remove_dead_funcs(&mut nopoly, [mn.node()])?; + remove_dead_funcs2(&mut nopoly)?; let mut funcs = list_funcs(&nopoly); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); for n in expected_mangled_names { - assert!(funcs.remove(&n).is_some()); + assert!(funcs.remove(&n).is_some(), "Did not find {n}"); } assert_eq!(funcs.keys().collect_vec(), vec![&"main"]); Ok(()) @@ -581,7 +586,7 @@ mod test { }; monomorphize(&mut hugr).unwrap(); - remove_dead_funcs(&mut hugr, []).unwrap(); + remove_dead_funcs2(&mut hugr).unwrap(); let funcs = list_funcs(&hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));