diff --git a/docs/source/api.rst b/docs/source/api.rst index 0d4e4168b8..5742633035 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -109,3 +109,5 @@ Return Iterator Types retworkx.BFSSuccessors retworkx.NodeIndices + retworkx.EdgeList + retworkx.WeightedEdgeList diff --git a/src/digraph.rs b/src/digraph.rs index 4e0c1fe783..4ca9bd8d93 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -41,7 +41,7 @@ use petgraph::visit::{ }; use super::dot_utils::build_dot; -use super::iterators::NodeIndices; +use super::iterators::{EdgeList, NodeIndices, WeightedEdgeList}; use super::{ is_directed_acyclic_graph, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes, NoSuitableNeighbors, NodesRemoved, @@ -661,11 +661,14 @@ impl PyDiGraph { /// ``source`` and ``target`` are the node indices. /// /// :returns: An edge list with weights - /// :rtype: list - pub fn edge_list(&self) -> Vec<(usize, usize)> { - self.edge_references() - .map(|edge| (edge.source().index(), edge.target().index())) - .collect() + /// :rtype: EdgeList + pub fn edge_list(&self) -> EdgeList { + EdgeList { + edges: self + .edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect(), + } } /// Get edge list with weights @@ -675,20 +678,20 @@ impl PyDiGraph { /// payload of the edge. /// /// :returns: An edge list with weights - /// :rtype: list - pub fn weighted_edge_list( - &self, - py: Python, - ) -> Vec<(usize, usize, PyObject)> { - self.edge_references() - .map(|edge| { - ( - edge.source().index(), - edge.target().index(), - edge.weight().clone_ref(py), - ) - }) - .collect() + /// :rtype: WeightedEdgeList + pub fn weighted_edge_list(&self, py: Python) -> WeightedEdgeList { + WeightedEdgeList { + edges: self + .edge_references() + .map(|edge| { + ( + edge.source().index(), + edge.target().index(), + edge.weight().clone_ref(py), + ) + }) + .collect(), + } } /// Remove a node from the graph. @@ -1160,21 +1163,22 @@ impl PyDiGraph { let mut edges_to_add: Vec<(usize, usize, PyObject)> = Vec::new(); for dir in &DIRECTIONS { - let edges; - if dir == &petgraph::Direction::Outgoing { - edges = self.out_edges(u); - } else { - edges = self.in_edges(u); - } - - for edge in edges { + for edge in self.graph.edges_directed(NodeIndex::new(u), *dir) { let s = edge.source(); let d = edge.target(); - if s == u { - edges_to_add.push((v, d, edge.weight().clone())); + if s.index() == u { + edges_to_add.push(( + v, + d.index(), + edge.weight().clone_ref(py), + )); } else { - edges_to_add.push((s, v, edge.weight().clone())); + edges_to_add.push(( + s.index(), + v, + edge.weight().clone_ref(py), + )); } } } @@ -1397,16 +1401,16 @@ impl PyDiGraph { /// /// :returns: A list of tuples of the form: /// ``(parent_index, node_index, edge_data)``` - /// :rtype: list + /// :rtype: WeightedEdgeList #[text_signature = "(self, node, /)"] - pub fn in_edges(&self, node: usize) -> Vec<(usize, usize, &PyObject)> { + pub fn in_edges(&self, py: Python, node: usize) -> WeightedEdgeList { let index = NodeIndex::new(node); let dir = petgraph::Direction::Incoming; let raw_edges = self.graph.edges_directed(index, dir); - let out_list: Vec<(usize, usize, &PyObject)> = raw_edges - .map(|x| (x.source().index(), node, x.weight())) + let out_list: Vec<(usize, usize, PyObject)> = raw_edges + .map(|x| (x.source().index(), node, x.weight().clone_ref(py))) .collect(); - out_list + WeightedEdgeList { edges: out_list } } /// Get the index and edge data for all children of a node. @@ -1418,16 +1422,16 @@ impl PyDiGraph { /// /// :returns out_edges: A list of tuples of the form: /// ```(node_index, child_index, edge_data)``` - /// :rtype: list + /// :rtype: WeightedEdgeList #[text_signature = "(self, node, /)"] - pub fn out_edges(&self, node: usize) -> Vec<(usize, usize, &PyObject)> { + pub fn out_edges(&self, py: Python, node: usize) -> WeightedEdgeList { let index = NodeIndex::new(node); let dir = petgraph::Direction::Outgoing; let raw_edges = self.graph.edges_directed(index, dir); - let out_list: Vec<(usize, usize, &PyObject)> = raw_edges - .map(|x| (node, x.target().index(), x.weight())) + let out_list: Vec<(usize, usize, PyObject)> = raw_edges + .map(|x| (node, x.target().index(), x.weight().clone_ref(py))) .collect(); - out_list + WeightedEdgeList { edges: out_list } } /// Add new nodes to the graph. diff --git a/src/graph.rs b/src/graph.rs index d6ab57f40b..4773fc1463 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -27,7 +27,7 @@ use pyo3::types::{PyDict, PyList, PyLong, PyString, PyTuple}; use pyo3::Python; use super::dot_utils::build_dot; -use super::iterators::NodeIndices; +use super::iterators::{EdgeList, NodeIndices, WeightedEdgeList}; use super::{NoEdgeBetweenNodes, NodesRemoved}; use petgraph::graph::{EdgeIndex, NodeIndex}; @@ -465,12 +465,15 @@ impl PyGraph { /// ``source`` and ``target`` are the node indices. /// /// :returns: An edge list with weights - /// :rtype: list + /// :rtype: EdgeList #[text_signature = "(self)"] - pub fn edge_list(&self) -> Vec<(usize, usize)> { - self.edge_references() - .map(|edge| (edge.source().index(), edge.target().index())) - .collect() + pub fn edge_list(&self) -> EdgeList { + EdgeList { + edges: self + .edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect(), + } } /// Get edge list with weights @@ -480,21 +483,21 @@ impl PyGraph { /// payload of the edge. /// /// :returns: An edge list with weights - /// :rtype: list + /// :rtype: WeightedEdgeList #[text_signature = "(self)"] - pub fn weighted_edge_list( - &self, - py: Python, - ) -> Vec<(usize, usize, PyObject)> { - self.edge_references() - .map(|edge| { - ( - edge.source().index(), - edge.target().index(), - edge.weight().clone_ref(py), - ) - }) - .collect() + pub fn weighted_edge_list(&self, py: Python) -> WeightedEdgeList { + WeightedEdgeList { + edges: self + .edge_references() + .map(|edge| { + ( + edge.source().index(), + edge.target().index(), + edge.weight().clone_ref(py), + ) + }) + .collect(), + } } /// Remove a node from the graph. diff --git a/src/iterators.rs b/src/iterators.rs index cb045a5f93..0dc8dc074c 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -183,3 +183,160 @@ impl PySequenceProtocol for NodeIndices { } } } + +/// A custom class for the return of edge lists +/// +/// This class is a container class for the results of functions that +/// return a list of edges. It implements the Python sequence +/// protocol. So you can treat the return as a read-only sequence/list +/// that is integer indexed. If you want to use it as an iterator you +/// can by wrapping it in an ``iter()`` that will yield the results in +/// order. +/// +/// For example:: +/// +/// import retworkx +/// +/// graph = retworkx.generators.directed_path_graph(5) +/// edges = graph.edge_list() +/// # Index based access +/// third_element = edges[2] +/// # Use as iterator +/// edges_iter = iter(edges) +/// first_element = next(edges_iter) +/// second_element = next(edges_iter) +/// +#[pyclass(module = "retworkx")] +pub struct EdgeList { + pub edges: Vec<(usize, usize)>, +} + +#[pyproto] +impl<'p> PyObjectProtocol<'p> for EdgeList { + fn __richcmp__( + &self, + other: &'p PySequence, + op: pyo3::basic::CompareOp, + ) -> PyResult { + let compare = |other: &PySequence| -> PyResult { + if other.len()? as usize != self.edges.len() { + return Ok(false); + } + for i in 0..self.edges.len() { + let other_raw = other.get_item(i.try_into().unwrap())?; + let other_value: (usize, usize) = other_raw.extract()?; + if other_value != self.edges[i] { + return Ok(false); + } + } + Ok(true) + }; + match op { + pyo3::basic::CompareOp::Eq => compare(other), + pyo3::basic::CompareOp::Ne => match compare(other) { + Ok(res) => Ok(!res), + Err(err) => Err(err), + }, + _ => Err(PyNotImplementedError::new_err( + "Comparison not implemented", + )), + } + } +} + +#[pyproto] +impl PySequenceProtocol for EdgeList { + fn __len__(&self) -> PyResult { + Ok(self.edges.len()) + } + + fn __getitem__(&'p self, idx: isize) -> PyResult<(usize, usize)> { + if idx >= self.edges.len().try_into().unwrap() { + Err(PyIndexError::new_err(format!("Invalid index, {}", idx))) + } else { + Ok(self.edges[idx as usize]) + } + } +} + +/// A custom class for the return of edge lists with weights +/// +/// This class is a container class for the results of functions that +/// return a list of edges with weights. It implements the Python sequence +/// protocol. So you can treat the return as a read-only sequence/list +/// that is integer indexed. If you want to use it as an iterator you +/// can by wrapping it in an ``iter()`` that will yield the results in +/// order. +/// +/// For example:: +/// +/// import retworkx +/// +/// graph = retworkx.generators.directed_path_graph(5) +/// edges = graph.weighted_edge_list() +/// # Index based access +/// third_element = edges[2] +/// # Use as iterator +/// edges_iter = iter(edges) +/// first_element = next(edges_iter) +/// second_element = next(edges_iter) +/// +#[pyclass(module = "retworkx")] +pub struct WeightedEdgeList { + pub edges: Vec<(usize, usize, PyObject)>, +} + +#[pyproto] +impl<'p> PyObjectProtocol<'p> for WeightedEdgeList { + fn __richcmp__( + &self, + other: &'p PySequence, + op: pyo3::basic::CompareOp, + ) -> PyResult { + let compare = |other: &PySequence| -> PyResult { + if other.len()? as usize != self.edges.len() { + return Ok(false); + } + let gil = Python::acquire_gil(); + let py = gil.python(); + for i in 0..self.edges.len() { + let other_raw = other.get_item(i.try_into().unwrap())?; + let other_value: (usize, usize, PyObject) = + other_raw.extract()?; + if other_value.0 != self.edges[i].0 + || other_value.1 != self.edges[i].1 + || self.edges[i].2.as_ref(py).compare(other_value.2)? + != std::cmp::Ordering::Equal + { + return Ok(false); + } + } + Ok(true) + }; + match op { + pyo3::basic::CompareOp::Eq => compare(other), + pyo3::basic::CompareOp::Ne => match compare(other) { + Ok(res) => Ok(!res), + Err(err) => Err(err), + }, + _ => Err(PyNotImplementedError::new_err( + "Comparison not implemented", + )), + } + } +} + +#[pyproto] +impl PySequenceProtocol for WeightedEdgeList { + fn __len__(&self) -> PyResult { + Ok(self.edges.len()) + } + + fn __getitem__(&'p self, idx: isize) -> PyResult<(usize, usize, PyObject)> { + if idx >= self.edges.len().try_into().unwrap() { + Err(PyIndexError::new_err(format!("Invalid index, {}", idx))) + } else { + Ok(self.edges[idx as usize].clone()) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 01cf43320b..3579568d84 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,7 +51,7 @@ use rand_pcg::Pcg64; use rayon::prelude::*; use crate::generators::PyInit_generators; -use crate::iterators::NodeIndices; +use crate::iterators::{EdgeList, NodeIndices}; trait NodesRemoved { fn nodes_removed(&self) -> bool; @@ -441,14 +441,16 @@ where /// /// :returns: A list of edges as a tuple of the form ``(source, target)`` in /// depth-first order -/// :rtype: list +/// :rtype: EdgeList #[pyfunction] #[text_signature = "(graph, /, source=None)"] fn digraph_dfs_edges( graph: &digraph::PyDiGraph, source: Option, -) -> Vec<(usize, usize)> { - dfs_edges(graph, source) +) -> EdgeList { + EdgeList { + edges: dfs_edges(graph, source), + } } /// Get edge list in depth first order @@ -462,13 +464,13 @@ fn digraph_dfs_edges( /// /// :returns: A list of edges as a tuple of the form ``(source, target)`` in /// depth-first order +/// :rtype: EdgeList #[pyfunction] #[text_signature = "(graph, /, source=None)"] -fn graph_dfs_edges( - graph: &graph::PyGraph, - source: Option, -) -> Vec<(usize, usize)> { - dfs_edges(graph, source) +fn graph_dfs_edges(graph: &graph::PyGraph, source: Option) -> EdgeList { + EdgeList { + edges: dfs_edges(graph, source), + } } /// Return successors in a breadth-first-search from a source node. @@ -2431,13 +2433,13 @@ pub fn strongly_connected_components( /// /// :returns: A list describing the cycle. The index of node ids which /// forms a cycle (loop) in the input graph -/// :rtype: list +/// :rtype: EdgeList #[pyfunction] #[text_signature = "(graph, /, source=None)"] pub fn digraph_find_cycle( graph: &digraph::PyDiGraph, source: Option, -) -> Vec<(usize, usize)> { +) -> EdgeList { let mut graph_nodes: HashSet = graph.graph.node_indices().collect(); let mut cycle: Vec<(usize, usize)> = Vec::new(); @@ -2483,7 +2485,7 @@ pub fn digraph_find_cycle( cycle.push((pred[&z].index(), z.index())); z = pred[&z]; } - return cycle; + return EdgeList { edges: cycle }; } //if an unexplored node is encountered if !visited.contains(&child) { @@ -2500,7 +2502,7 @@ pub fn digraph_find_cycle( visited.insert(z); } } - cycle + EdgeList { edges: cycle } } // The provided node is invalid. @@ -2574,6 +2576,8 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pymodule!(generators))?; Ok(()) } diff --git a/src/union.rs b/src/union.rs index 628f5dc019..357cc46edf 100644 --- a/src/union.rs +++ b/src/union.rs @@ -62,7 +62,7 @@ pub fn digraph_union( node_map.insert(node.index(), node_index); } - for edge in b.weighted_edge_list(py) { + for edge in b.weighted_edge_list(py).edges { let source = edge.0; let target = edge.1; let edge_weight = edge.2; diff --git a/tests/test_custom_return_types.py b/tests/test_custom_return_types.py index a5ce01bac9..8f00a9125d 100644 --- a/tests/test_custom_return_types.py +++ b/tests/test_custom_return_types.py @@ -100,3 +100,77 @@ def test__ne__invalid_type(self): def test__gt__not_implemented(self): with self.assertRaises(NotImplementedError): self.dag.node_indexes() > [2, 1] + + +class TestEdgeListComparisons(unittest.TestCase): + + def setUp(self): + self.dag = retworkx.PyDAG() + node_a = self.dag.add_node('a') + self.dag.add_child(node_a, 'b', "Edgy") + + def test__eq__match(self): + self.assertTrue(self.dag.edge_list() == [(0, 1)]) + + def test__eq__not_match(self): + self.assertFalse(self.dag.edge_list() == [(1, 2)]) + + def test__eq__different_length(self): + self.assertFalse(self.dag.edge_list() == [(0, 1), (2, 3)]) + + def test__eq__invalid_type(self): + self.assertFalse(self.dag.edge_list() == ['a', None]) + + def test__ne__match(self): + self.assertFalse(self.dag.edge_list() != [(0, 1)]) + + def test__ne__not_match(self): + self.assertTrue(self.dag.edge_list() != [(1, 2)]) + + def test__ne__different_length(self): + self.assertTrue(self.dag.edge_list() != [(0, 1), (2, 3)]) + + def test__ne__invalid_type(self): + self.assertTrue(self.dag.edge_list() != ['a', None]) + + def test__gt__not_implemented(self): + with self.assertRaises(NotImplementedError): + self.dag.edge_list() > [(2, 1)] + + +class TestWeightedEdgeListComparisons(unittest.TestCase): + + def setUp(self): + self.dag = retworkx.PyDAG() + node_a = self.dag.add_node('a') + self.dag.add_child(node_a, 'b', "Edgy") + + def test__eq__match(self): + self.assertTrue(self.dag.weighted_edge_list() == [(0, 1, 'Edgy')]) + + def test__eq__not_match(self): + self.assertFalse(self.dag.weighted_edge_list() == [(1, 2, None)]) + + def test__eq__different_length(self): + self.assertFalse( + self.dag.weighted_edge_list() == [ + (0, 1, 'Edgy'), (2, 3, 'Not Edgy')]) + + def test__eq__invalid_type(self): + self.assertFalse(self.dag.weighted_edge_list() == ['a', None]) + + def test__ne__match(self): + self.assertFalse(self.dag.weighted_edge_list() != [(0, 1, 'Edgy')]) + + def test__ne__not_match(self): + self.assertTrue(self.dag.weighted_edge_list() != [(1, 2, 'Not Edgy')]) + + def test__ne__different_length(self): + self.assertTrue(self.dag.node_indexes() != [0, 1, 2, 3]) + + def test__ne__invalid_type(self): + self.assertTrue(self.dag.weighted_edge_list() != ['a', None]) + + def test__gt__not_implemented(self): + with self.assertRaises(NotImplementedError): + self.dag.weighted_edge_list() > [(2, 1, 'Not Edgy')]