diff --git a/src/digraph.rs b/src/digraph.rs index afee542121..86769e5cb7 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -31,6 +31,8 @@ use petgraph::algo; use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::stable_graph::StableDiGraph; +use petgraph::stable_graph::StableUnGraph; + use petgraph::visit::{ GetAdjacencyMatrix, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges, IntoEdgesDirected, IntoNeighbors, IntoNeighborsDirected, @@ -1605,6 +1607,37 @@ impl PyDiGraph { } edges.is_empty() } + + /// Generate a new PyGraph object from this graph + /// + /// This will create a new :class:`~retworkx.PyGraph` object from this + /// graph. All edges in this graph will be created as undirected edges in + /// the new graph object. + /// Do note that the node and edge weights/data payloads will be passed + /// by reference to the new :class:`~retworkx.PyGraph` object. + /// + /// :returns: A new PyGraph object with an undirected edge for every + /// directed edge in this graph + /// :rtype: PyGraph + pub fn to_undirected(&self, py: Python) -> crate::graph::PyGraph { + let mut new_graph = StableUnGraph::::default(); + let mut node_map: HashMap = HashMap::new(); + for node_index in self.graph.node_indices() { + let node = self.graph[node_index].clone_ref(py); + let new_index = new_graph.add_node(node); + node_map.insert(node_index, new_index); + } + for edge in self.edge_references() { + let source = node_map.get(&edge.source()).unwrap(); + let target = node_map.get(&edge.target()).unwrap(); + let weight = edge.weight().clone_ref(py); + new_graph.add_edge(*source, *target, weight); + } + crate::graph::PyGraph { + graph: new_graph, + node_removed: false, + } + } } #[pyproto] diff --git a/tests/test_to_undirected.py b/tests/test_to_undirected.py new file mode 100644 index 0000000000..d7c2ab8672 --- /dev/null +++ b/tests/test_to_undirected.py @@ -0,0 +1,55 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import retworkx + + +class TestToUndirected(unittest.TestCase): + + def test_to_undirected_empty_graph(self): + digraph = retworkx.PyDiGraph() + graph = digraph.to_undirected() + self.assertEqual(0, len(graph)) + + def test_single_direction_graph(self): + digraph = retworkx.generators.directed_path_graph(5) + graph = digraph.to_undirected() + self.assertEqual(digraph.weighted_edge_list(), + graph.weighted_edge_list()) + + def test_bidirectional_graph(self): + digraph = retworkx.generators.directed_path_graph(5) + for i in range(0, 4): + digraph.add_edge(i + 1, i, None) + graph = digraph.to_undirected() + self.assertEqual(digraph.weighted_edge_list(), + graph.weighted_edge_list()) + + def test_shared_ref(self): + digraph = retworkx.PyDiGraph() + node_weight = {'a': 1} + node_a = digraph.add_node(node_weight) + edge_weight = {'a': 1} + digraph.add_child(node_a, 'b', edge_weight) + graph = digraph.to_undirected() + self.assertEqual(digraph[node_a], {'a': 1}) + self.assertEqual(graph[node_a], {'a': 1}) + node_weight['b'] = 2 + self.assertEqual(digraph[node_a], {'a': 1, 'b': 2}) + self.assertEqual(graph[node_a], {'a': 1, 'b': 2}) + self.assertEqual(digraph.get_edge_data(0, 1), {'a': 1}) + self.assertEqual(graph.get_edge_data(0, 1), {'a': 1}) + edge_weight['b'] = 2 + self.assertEqual(digraph.get_edge_data(0, 1), {'a': 1, 'b': 2}) + self.assertEqual(graph.get_edge_data(0, 1), {'a': 1, 'b': 2})