From 208b1d0b43cd1d3479ee1a52a45490cfb2a2f1ac Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Fri, 2 Oct 2020 17:00:15 -0400 Subject: [PATCH 1/3] Add to_undirected method for PyDiGraph This commit adds a new method to the PyDiGraph class, to_undirected(), which will generate an undirected PyGraph object from the PyDiGraph object. Fixes #153 --- src/digraph.rs | 33 ++++++++++++++++++++++ tests/test_to_undirected.py | 55 +++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 tests/test_to_undirected.py diff --git a/src/digraph.rs b/src/digraph.rs index f52cf7284b..3c8dd634b0 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -30,6 +30,8 @@ use petgraph::algo; use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::stable_graph::StableDiGraph; +use petgraph::stable_graph::StableUnGraph; + use petgraph::visit::{ GetAdjacencyMatrix, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges, IntoEdgesDirected, IntoNeighbors, IntoNeighborsDirected, @@ -1536,6 +1538,37 @@ impl PyDiGraph { } edges.is_empty() } + + /// Generate a new PyGraph object from this graph + /// + /// This will create a new :class:`~retworkx.PyGraph` object from this + /// graph. All edges in this graph will be created as undirected edges in + /// the new graph obejct. + /// Do note that the node and edge weights/data payloads will be passed + /// by reference to the new :class:`~retworkx.PyGraph` object. + /// + /// :returns: A new PyGraph object with an undirected edge for every + /// directed edge in this graph + /// :rtype: PyGraph + pub fn to_undirected(&self, py: Python) -> crate::graph::PyGraph { + let mut new_graph = StableUnGraph::::default(); + let mut node_map: HashMap = HashMap::new(); + for node_index in self.graph.node_indices() { + let node = self.graph[node_index].clone_ref(py); + let new_index = new_graph.add_node(node); + node_map.insert(node_index, new_index); + } + for edge in self.edge_references() { + let source = node_map.get(&edge.source()).unwrap(); + let target = node_map.get(&edge.target()).unwrap(); + let weight = edge.weight().clone_ref(py); + new_graph.add_edge(*source, *target, weight); + } + crate::graph::PyGraph { + graph: new_graph, + node_removed: false, + } + } } #[pyproto] diff --git a/tests/test_to_undirected.py b/tests/test_to_undirected.py new file mode 100644 index 0000000000..23118fb35c --- /dev/null +++ b/tests/test_to_undirected.py @@ -0,0 +1,55 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import retworkx + + +class TestToUndirected(unittest.TestCase): + + def test_to_undirected_empty_graph(self): + digraph = retworkx.PyDiGraph() + graph = digraph.to_undirected() + self.assertEqual(0, len(graph)) + + def test_single_direction_graph(self): + digraph = retworkx.generators.directed_path_graph(5) + graph = digraph.to_undirected() + self.assertEqual(digraph.weighted_edge_list(), + graph.weighted_edge_list()) + + def test_bidirectional_graph(self): + digraph = retworkx.generators.directed_path_graph(5) + for i in range(0, 4): + digraph.add_edge(i+1, i, None) + graph = digraph.to_undirected() + self.assertEqual(digraph.weighted_edge_list(), + graph.weighted_edge_list()) + + def test_shared_ref(self): + digraph = retworkx.PyDiGraph() + node_weight = {'a': 1} + node_a = digraph.add_node(node_weight) + edge_weight = {'a': 1} + digraph.add_child(node_a, 'b', edge_weight) + graph = digraph.to_undirected() + self.assertEqual(digraph[node_a], {'a': 1}) + self.assertEqual(graph[node_a], {'a': 1}) + node_weight['b'] = 2 + self.assertEqual(digraph[node_a], {'a': 1, 'b': 2}) + self.assertEqual(graph[node_a], {'a': 1, 'b': 2}) + self.assertEqual(digraph.get_edge_data(0, 1), {'a': 1}) + self.assertEqual(graph.get_edge_data(0, 1), {'a': 1}) + edge_weight['b'] = 2 + self.assertEqual(digraph.get_edge_data(0, 1), {'a': 1, 'b': 2}) + self.assertEqual(graph.get_edge_data(0, 1), {'a': 1, 'b': 2}) From fd3de11696395e4927cf0968f31bf967a5878c30 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Fri, 2 Oct 2020 17:05:50 -0400 Subject: [PATCH 2/3] Fix lint --- tests/test_to_undirected.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_to_undirected.py b/tests/test_to_undirected.py index 23118fb35c..d7c2ab8672 100644 --- a/tests/test_to_undirected.py +++ b/tests/test_to_undirected.py @@ -31,7 +31,7 @@ def test_single_direction_graph(self): def test_bidirectional_graph(self): digraph = retworkx.generators.directed_path_graph(5) for i in range(0, 4): - digraph.add_edge(i+1, i, None) + digraph.add_edge(i + 1, i, None) graph = digraph.to_undirected() self.assertEqual(digraph.weighted_edge_list(), graph.weighted_edge_list()) From de5e976e2418d74d0eb9bac4f43f8790de327e77 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Mon, 2 Nov 2020 07:06:21 -0500 Subject: [PATCH 3/3] Fix type in src/digraph.rs --- src/digraph.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/digraph.rs b/src/digraph.rs index af58d04b0e..17f01b05b2 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -1568,7 +1568,7 @@ impl PyDiGraph { /// /// This will create a new :class:`~retworkx.PyGraph` object from this /// graph. All edges in this graph will be created as undirected edges in - /// the new graph obejct. + /// the new graph object. /// Do note that the node and edge weights/data payloads will be passed /// by reference to the new :class:`~retworkx.PyGraph` object. ///