diff --git a/docs/source/api.rst b/docs/source/api.rst index c7e0630fed..0d4e4168b8 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -107,4 +107,5 @@ Return Iterator Types .. autosummary:: :toctree: stubs - retworkx.BFSSuccessors \ No newline at end of file + retworkx.BFSSuccessors + retworkx.NodeIndices diff --git a/src/digraph.rs b/src/digraph.rs index 4347c71829..cf0e18c40f 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -41,6 +41,7 @@ use petgraph::visit::{ }; use super::dot_utils::build_dot; +use super::iterators::NodeIndices; use super::{ is_directed_acyclic_graph, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes, NoSuitableNeighbors, NodesRemoved, @@ -510,10 +511,12 @@ impl PyDiGraph { /// Return a list of all node indexes. /// /// :returns: A list of all the node indexes in the graph - /// :rtype: list + /// :rtype: NodeIndices #[text_signature = "()"] - pub fn node_indexes(&self) -> Vec { - self.graph.node_indices().map(|node| node.index()).collect() + pub fn node_indexes(&self) -> NodeIndices { + NodeIndices { + nodes: self.graph.node_indices().map(|node| node.index()).collect(), + } } /// Return True if there is an edge from node_a to node_b. @@ -1327,13 +1330,16 @@ impl PyDiGraph { /// :param int node: The index of the node to get the neighbors of /// /// :returns: A list of the neighbor node indicies - /// :rtype: list + /// :rtype: NodeIndices #[text_signature = "(node, /)"] - pub fn neighbors(&self, node: usize) -> Vec { - self.graph - .neighbors(NodeIndex::new(node)) - .map(|node| node.index()) - .collect() + pub fn neighbors(&self, node: usize) -> NodeIndices { + NodeIndices { + nodes: self + .graph + .neighbors(NodeIndex::new(node)) + .map(|node| node.index()) + .collect(), + } } /// Get the successor indices of a node. @@ -1344,16 +1350,19 @@ impl PyDiGraph { /// :param int node: The index of the node to get the successors of /// /// :returns: A list of the neighbor node indicies - /// :rtype: list + /// :rtype: NodeIndices #[text_signature = "(node, /)"] - pub fn successor_indices(&mut self, node: usize) -> Vec { - self.graph - .neighbors_directed( - NodeIndex::new(node), - petgraph::Direction::Outgoing, - ) - .map(|node| node.index()) - .collect() + pub fn successor_indices(&mut self, node: usize) -> NodeIndices { + NodeIndices { + nodes: self + .graph + .neighbors_directed( + NodeIndex::new(node), + petgraph::Direction::Outgoing, + ) + .map(|node| node.index()) + .collect(), + } } /// Get the predecessor indices of a node. @@ -1364,16 +1373,19 @@ impl PyDiGraph { /// :param int node: The index of the node to get the predecessors of /// /// :returns: A list of the neighbor node indicies - /// :rtype: list + /// :rtype: NodeIndices #[text_signature = "(node, /)"] - pub fn predecessor_indices(&mut self, node: usize) -> Vec { - self.graph - .neighbors_directed( - NodeIndex::new(node), - petgraph::Direction::Incoming, - ) - .map(|node| node.index()) - .collect() + pub fn predecessor_indices(&mut self, node: usize) -> NodeIndices { + NodeIndices { + nodes: self + .graph + .neighbors_directed( + NodeIndex::new(node), + petgraph::Direction::Incoming, + ) + .map(|node| node.index()) + .collect(), + } } /// Get the index and edge data for all parents of a node. /// @@ -1424,15 +1436,15 @@ impl PyDiGraph { /// as new nodes /// /// :returns: A list of int indices of the newly created nodes - /// :rtype: list + /// :rtype: NodeIndices #[text_signature = "(obj_list, /)"] - pub fn add_nodes_from(&mut self, obj_list: Vec) -> Vec { + pub fn add_nodes_from(&mut self, obj_list: Vec) -> NodeIndices { let mut out_list: Vec = Vec::new(); for obj in obj_list { let node_index = self.graph.add_node(obj); out_list.push(node_index.index()); } - out_list + NodeIndices { nodes: out_list } } /// Remove nodes from the graph. diff --git a/src/graph.rs b/src/graph.rs index e5c7176f01..6387c52801 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -27,7 +27,9 @@ use pyo3::types::{PyDict, PyList, PyLong, PyString, PyTuple}; use pyo3::Python; use super::dot_utils::build_dot; +use super::iterators::NodeIndices; use super::{NoEdgeBetweenNodes, NodesRemoved}; + use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::stable_graph::StableUnGraph; @@ -355,10 +357,12 @@ impl PyGraph { /// Return a list of all node indexes. /// /// :returns: A list of all the node indexes in the graph - /// :rtype: list + /// :rtype: NodeIndices #[text_signature = "()"] - pub fn node_indexes(&self) -> Vec { - self.graph.node_indices().map(|node| node.index()).collect() + pub fn node_indexes(&self) -> NodeIndices { + NodeIndices { + nodes: self.graph.node_indices().map(|node| node.index()).collect(), + } } /// Return True if there is an edge between node_a to node_b. @@ -719,15 +723,15 @@ impl PyGraph { /// :param list obj_list: A list of python object to attach to the graph. /// /// :returns indices: A list of int indices of the newly created nodes - /// :rtype: list + /// :rtype: NodeIndices #[text_signature = "(obj_list, /)"] - pub fn add_nodes_from(&mut self, obj_list: Vec) -> Vec { + pub fn add_nodes_from(&mut self, obj_list: Vec) -> NodeIndices { let mut out_list: Vec = Vec::new(); for obj in obj_list { let node_index = self.graph.add_node(obj); out_list.push(node_index.index()); } - out_list + NodeIndices { nodes: out_list } } /// Remove nodes from the graph. @@ -783,13 +787,16 @@ impl PyGraph { /// :param int node: The index of the node to get the neibhors of /// /// :returns: A list of the neighbor node indicies - /// :rtype: list + /// :rtype: NodeIndices #[text_signature = "(node, /)"] - pub fn neighbors(&self, node: usize) -> Vec { - self.graph - .neighbors(NodeIndex::new(node)) - .map(|node| node.index()) - .collect() + pub fn neighbors(&self, node: usize) -> NodeIndices { + NodeIndices { + nodes: self + .graph + .neighbors(NodeIndex::new(node)) + .map(|node| node.index()) + .collect(), + } } /// Get the degree for a node diff --git a/src/iterators.rs b/src/iterators.rs index bdcd463bfe..cb045a5f93 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -108,3 +108,78 @@ impl PySequenceProtocol for BFSSuccessors { } } } + +/// A custom class for the return of node indices +/// +/// This class is a container class for the results of functions that +/// return a list of node indices. It implements the Python sequence +/// protocol. So you can treat the return as a read-only sequence/list +/// that is integer indexed. If you want to use it as an iterator you +/// can by wrapping it in an ``iter()`` that will yield the results in +/// order. +/// +/// For example:: +/// +/// import retworkx +/// +/// graph = retworkx.generators.directed_path_graph(5) +/// nodes = retworkx.node_indexes(0) +/// # Index based access +/// third_element = nodes[2] +/// # Use as iterator +/// nodes_iter = iter(node) +/// first_element = next(nodes_iter) +/// second_element = next(nodes_iter) +/// +#[pyclass(module = "retworkx")] +pub struct NodeIndices { + pub nodes: Vec, +} + +#[pyproto] +impl<'p> PyObjectProtocol<'p> for NodeIndices { + fn __richcmp__( + &self, + other: &'p PySequence, + op: pyo3::basic::CompareOp, + ) -> PyResult { + let compare = |other: &PySequence| -> PyResult { + if other.len()? as usize != self.nodes.len() { + return Ok(false); + } + for i in 0..self.nodes.len() { + let other_raw = other.get_item(i.try_into().unwrap())?; + let other_value: usize = other_raw.extract()?; + if other_value != self.nodes[i] { + return Ok(false); + } + } + Ok(true) + }; + match op { + pyo3::basic::CompareOp::Eq => compare(other), + pyo3::basic::CompareOp::Ne => match compare(other) { + Ok(res) => Ok(!res), + Err(err) => Err(err), + }, + _ => Err(PyNotImplementedError::new_err( + "Comparison not implemented", + )), + } + } +} + +#[pyproto] +impl PySequenceProtocol for NodeIndices { + fn __len__(&self) -> PyResult { + Ok(self.nodes.len()) + } + + fn __getitem__(&'p self, idx: isize) -> PyResult { + if idx >= self.nodes.len().try_into().unwrap() { + Err(PyIndexError::new_err(format!("Invalid index, {}", idx))) + } else { + Ok(self.nodes[idx as usize]) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 28a41af6e5..631ed46cf4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,7 @@ use rand_pcg::Pcg64; use rayon::prelude::*; use crate::generators::PyInit_generators; +use crate::iterators::NodeIndices; trait NodesRemoved { fn nodes_removed(&self) -> bool; @@ -113,14 +114,16 @@ fn longest_path(graph: &digraph::PyDiGraph) -> PyResult> { /// object must be a DAG without a cycle. /// /// :returns: The node indices of the longest path on the DAG -/// :rtype: list +/// :rtype: NodeIndices /// /// :raises Exception: If an unexpected error occurs or a path can't be found /// :raises DAGHasCycle: If the input PyDiGraph has a cycle #[pyfunction] #[text_signature = "(graph, /)"] -fn dag_longest_path(graph: &digraph::PyDiGraph) -> PyResult> { - longest_path(graph) +fn dag_longest_path(graph: &digraph::PyDiGraph) -> PyResult { + Ok(NodeIndices { + nodes: longest_path(graph)?, + }) } /// Find the length of the longest path in a DAG @@ -336,19 +339,21 @@ fn is_isomorphic_node_match( /// :param PyDiGraph graph: The DAG to get the topological sort on /// /// :returns: A list of node indices topologically sorted. -/// :rtype: list +/// :rtype: NodeIndices /// /// :raises DAGHasCycle: if a cycle is encountered while sorting the graph #[pyfunction] #[text_signature = "(graph, /)"] -fn topological_sort(graph: &digraph::PyDiGraph) -> PyResult> { +fn topological_sort(graph: &digraph::PyDiGraph) -> PyResult { let nodes = match algo::toposort(graph, None) { Ok(nodes) => nodes, Err(_err) => { return Err(DAGHasCycle::new_err("Sort encountered a cycle")) } }; - Ok(nodes.iter().map(|node| node.index()).collect()) + Ok(NodeIndices { + nodes: nodes.iter().map(|node| node.index()).collect(), + }) } fn dfs_edges(graph: G, source: Option) -> Vec<(usize, usize)> @@ -1808,7 +1813,7 @@ fn digraph_dijkstra_shortest_path_lengths( /// /// :returns: The computed shortest path between node and finish as a list /// of node indices. -/// :rtype: list +/// :rtype: NodeIndices #[pyfunction] #[text_signature = "(graph, node, goal_fn, edge_cost, estimate_cost, /)"] fn graph_astar_shortest_path( @@ -1818,7 +1823,7 @@ fn graph_astar_shortest_path( goal_fn: PyObject, edge_cost_fn: PyObject, estimate_cost_fn: PyObject, -) -> PyResult> { +) -> PyResult { let goal_fn_callable = |a: &PyObject| -> PyResult { let res = goal_fn.call1(py, (a,))?; let raw = res.to_object(py); @@ -1858,7 +1863,9 @@ fn graph_astar_shortest_path( )) } }; - Ok(path.1.into_iter().map(|x| x.index()).collect()) + Ok(NodeIndices { + nodes: path.1.into_iter().map(|x| x.index()).collect(), + }) } /// Compute the A* shortest path for a PyDiGraph @@ -1880,7 +1887,7 @@ fn graph_astar_shortest_path( /// /// :return: The computed shortest path between node and finish as a list /// of node indices. -/// :rtype: list +/// :rtype: NodeIndices #[pyfunction] #[text_signature = "(graph, node, goal_fn, edge_cost, estimate_cost, /)"] fn digraph_astar_shortest_path( @@ -1890,7 +1897,7 @@ fn digraph_astar_shortest_path( goal_fn: PyObject, edge_cost_fn: PyObject, estimate_cost_fn: PyObject, -) -> PyResult> { +) -> PyResult { let goal_fn_callable = |a: &PyObject| -> PyResult { let res = goal_fn.call1(py, (a,))?; let raw = res.to_object(py); @@ -1930,7 +1937,9 @@ fn digraph_astar_shortest_path( )) } }; - Ok(path.1.into_iter().map(|x| x.index()).collect()) + Ok(NodeIndices { + nodes: path.1.into_iter().map(|x| x.index()).collect(), + }) } /// Return a :math:`G_{np}` directed random graph, also known as an @@ -2551,6 +2560,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pymodule!(generators))?; Ok(()) } diff --git a/tests/test_custom_return_types.py b/tests/test_custom_return_types.py index ba5165a58c..a5ce01bac9 100644 --- a/tests/test_custom_return_types.py +++ b/tests/test_custom_return_types.py @@ -15,7 +15,7 @@ import retworkx -class TestCustomReturnTypeComparisons(unittest.TestCase): +class TestBFSSuccessorsComparisons(unittest.TestCase): def setUp(self): self.dag = retworkx.PyDAG() @@ -62,3 +62,41 @@ def test__ne__invalid_type(self): def test__gt__not_implemented(self): with self.assertRaises(NotImplementedError): retworkx.bfs_successors(self.dag, 0) > [('b', ['c'])] + + +class TestNodeIndicesComparisons(unittest.TestCase): + + def setUp(self): + self.dag = retworkx.PyDAG() + node_a = self.dag.add_node('a') + self.dag.add_child(node_a, 'b', "Edgy") + + def test__eq__match(self): + self.assertTrue(self.dag.node_indexes() == [0, 1]) + + def test__eq__not_match(self): + self.assertFalse(self.dag.node_indexes() == [1, 2]) + + def test__eq__different_length(self): + self.assertFalse(self.dag.node_indexes() == [0, 1, 2, 3]) + + def test__eq__invalid_type(self): + with self.assertRaises(TypeError): + self.dag.node_indexes() == ['a', None] + + def test__ne__match(self): + self.assertFalse(self.dag.node_indexes() != [0, 1]) + + def test__ne__not_match(self): + self.assertTrue(self.dag.node_indexes() != [1, 2]) + + def test__ne__different_length(self): + self.assertTrue(self.dag.node_indexes() != [0, 1, 2, 3]) + + def test__ne__invalid_type(self): + with self.assertRaises(TypeError): + self.dag.node_indexes() != ['a', None] + + def test__gt__not_implemented(self): + with self.assertRaises(NotImplementedError): + self.dag.node_indexes() > [2, 1]