Skip to content
Merged
142 changes: 142 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,52 @@ impl PyDiGraph {
let edge = self.graph.add_edge(p_index, c_index, edge);
Ok(edge.index())
}

fn insert_between(
&mut self,
py: Python,
node: usize,
node_between: usize,
direction: bool,
) -> PyResult<()> {
let dir = if direction {
petgraph::Direction::Outgoing
} else {
petgraph::Direction::Incoming
};
let index = NodeIndex::new(node);
let node_between_index = NodeIndex::new(node_between);
let edges: Vec<(NodeIndex, EdgeIndex, PyObject)> = self
.graph
.edges_directed(node_between_index, dir)
.map(|edge| {
if direction {
(edge.target(), edge.id(), edge.weight().clone_ref(py))
} else {
(edge.source(), edge.id(), edge.weight().clone_ref(py))
}
})
.collect::<Vec<(NodeIndex, EdgeIndex, PyObject)>>();
for (other_index, edge_index, weight) in edges {
if direction {
self._add_edge(
node_between_index,
index,
weight.clone_ref(py),
)?;
self._add_edge(index, other_index, weight.clone_ref(py))?;
} else {
self._add_edge(other_index, index, weight.clone_ref(py))?;
self._add_edge(
index,
node_between_index,
weight.clone_ref(py),
)?;
}
self.graph.remove_edge(edge_index);
}
Ok(())
}
}

#[pymethods]
Expand Down Expand Up @@ -864,6 +910,102 @@ impl PyDiGraph {
Ok(())
}

/// Insert a node between a list of reference nodes and all their predecessors
///
/// This essentially iterates over all edges into the reference node
/// specified in the ``ref_nodes`` parameter removes those edges and then
/// adds 2 edges, one from the predecessor of ``ref_node`` to ``node``
/// and the other from ``node`` to ``ref_node``. The edge payloads for
/// the newly created edges are copied by reference from the original
/// edge that gets removed.
///
/// :param int node: The node index to insert between
/// :param int ref_node: The reference node index to insert ``node``
/// between
#[text_signature = "(node, ref_nodes, /)"]
pub fn insert_node_on_in_edges_multiple(
&mut self,
py: Python,
node: usize,
ref_nodes: Vec<usize>,
) -> PyResult<()> {
for ref_node in ref_nodes {
self.insert_between(py, node, ref_node, false)?;
}
Ok(())
}

/// Insert a node between a list of reference nodes and all their successors
///
/// This essentially iterates over all edges out of the reference node
/// specified in the ``ref_node`` parameter removes those edges and then
/// adds 2 edges, one from ``ref_node`` to ``node`` and the other from
/// ``node`` to the successor of ``ref_node``. The edge payloads for the
/// newly created edges are copied by reference from the original edge that
/// gets removed.
///
/// :param int node: The node index to insert between
/// :param int ref_nodes: The list of node indices to insert ``node``
/// between
#[text_signature = "(node, ref_nodes, /)"]
pub fn insert_node_on_out_edges_multiple(
&mut self,
py: Python,
node: usize,
ref_nodes: Vec<usize>,
) -> PyResult<()> {
for ref_node in ref_nodes {
self.insert_between(py, node, ref_node, true)?;
}
Ok(())
}

/// Insert a node between a reference node and all its predecessor nodes
///
/// This essentially iterates over all edges into the reference node
/// specified in the ``ref_node`` parameter removes those edges and then
/// adds 2 edges, one from the predecessor of ``ref_node`` to ``node`` and
/// the other from ``node`` to ``ref_node``. The edge payloads for the
/// newly created edges are copied by reference from the original edge that
/// gets removed.
///
/// :param int node: The node index to insert between
/// :param int ref_node: The reference node index to insert ``node``
/// between
#[text_signature = "(node, ref_node, /)"]
pub fn insert_node_on_in_edges(
&mut self,
py: Python,
node: usize,
ref_node: usize,
) -> PyResult<()> {
self.insert_between(py, node, ref_node, false)?;
Ok(())
}

/// Insert a node between a reference node and all its successor nodes
///
/// This essentially iterates over all edges out of the reference node
/// specified in the ``ref_node`` parameter removes those edges and then
/// adds 2 edges, one from ``ref_node`` to ``node`` and the other from
/// ``node`` to the successor of ``ref_node``. The edge payloads for the
/// newly created edges are copied by reference from the original edge
/// that gets removed.
///
/// :param int node: The node index to insert between
/// :param int ref_node: The reference node index to insert ``node``
/// between
#[text_signature = "(node, ref_node, /)"]
pub fn insert_node_on_out_edges(
&mut self,
py: Python,
node: usize,
ref_node: usize,
) -> PyResult<()> {
self.insert_between(py, node, ref_node, true)?;
Ok(())
}

/// Remove an edge between 2 nodes.
///
/// Note if there are multiple edges between the specified nodes only one
Expand Down
138 changes: 138 additions & 0 deletions tests/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,141 @@ def test_extend_from_weighted_edge_list_nodes_exist(self):
self.assertEqual(1, dag.out_degree(1))
self.assertEqual(1, dag.out_degree(2))
self.assertEqual(2, dag.in_degree(3))

def test_insert_node_on_in_edges(self):
graph = retworkx.PyDiGraph()
in_node = graph.add_node('qr[0]')
out_node = graph.add_child(in_node, 'qr[0]', 'qr[0]')
h_gate = graph.add_node('h')
graph.insert_node_on_in_edges(h_gate, out_node)
self.assertEqual(
[(in_node, h_gate, 'qr[0]'), (h_gate, out_node, 'qr[0]')],
graph.weighted_edge_list())

def test_insert_node_on_in_edges_multiple(self):
graph = retworkx.PyDiGraph()
in_node_0 = graph.add_node('qr[0]')
out_node_0 = graph.add_child(in_node_0, 'qr[0]', 'qr[0]')
in_node_1 = graph.add_node('qr[1]')
out_node_1 = graph.add_child(in_node_1, 'qr[1]', 'qr[1]')
cx_gate = graph.add_node('cx')
graph.insert_node_on_in_edges_multiple(cx_gate,
[out_node_0, out_node_1])
self.assertEqual(
{(in_node_0, cx_gate, 'qr[0]'), (cx_gate, out_node_0, 'qr[0]'),
(in_node_1, cx_gate, 'qr[1]'), (cx_gate, out_node_1, 'qr[1]')},
set(graph.weighted_edge_list()))

def test_insert_node_on_in_edges_double(self):
graph = retworkx.PyDiGraph()
in_node = graph.add_node('qr[0]')
out_node = graph.add_child(in_node, 'qr[0]', 'qr[0]')
h_gate = graph.add_node('h')
z_gate = graph.add_node('z')
graph.insert_node_on_in_edges(h_gate, out_node)
graph.insert_node_on_in_edges(z_gate, out_node)
self.assertEqual(
{(in_node, h_gate, 'qr[0]'), (h_gate, z_gate, 'qr[0]'),
(z_gate, out_node, 'qr[0]')},
set(graph.weighted_edge_list()))

def test_insert_node_on_in_edges_multiple_double(self):
graph = retworkx.PyDiGraph()
in_node_0 = graph.add_node('qr[0]')
out_node_0 = graph.add_child(in_node_0, 'qr[0]', 'qr[0]')
in_node_1 = graph.add_node('qr[1]')
out_node_1 = graph.add_child(in_node_1, 'qr[1]', 'qr[1]')
cx_gate = graph.add_node('cx')
cz_gate = graph.add_node('cz')
graph.insert_node_on_in_edges_multiple(cx_gate,
[out_node_0, out_node_1])
graph.insert_node_on_in_edges_multiple(cz_gate,
[out_node_0, out_node_1])
self.assertEqual(
{(in_node_0, cx_gate, 'qr[0]'), (cx_gate, cz_gate, 'qr[0]'),
(in_node_1, cx_gate, 'qr[1]'), (cx_gate, cz_gate, 'qr[1]'),
(cz_gate, out_node_0, 'qr[0]'), (cz_gate, out_node_1, 'qr[1]')},
set(graph.weighted_edge_list()))

def test_insert_node_on_out_edges(self):
graph = retworkx.PyDiGraph()
in_node = graph.add_node('qr[0]')
out_node = graph.add_child(in_node, 'qr[0]', 'qr[0]')
h_gate = graph.add_node('h')
graph.insert_node_on_out_edges(h_gate, in_node)
self.assertEqual(
{(in_node, h_gate, 'qr[0]'), (h_gate, out_node, 'qr[0]')},
set(graph.weighted_edge_list()))

def test_insert_node_on_out_edges_multiple(self):
graph = retworkx.PyDiGraph()
in_node_0 = graph.add_node('qr[0]')
out_node_0 = graph.add_child(in_node_0, 'qr[0]', 'qr[0]')
in_node_1 = graph.add_node('qr[1]')
out_node_1 = graph.add_child(in_node_1, 'qr[1]', 'qr[1]')
cx_gate = graph.add_node('cx')
graph.insert_node_on_out_edges_multiple(cx_gate,
[in_node_0, in_node_1])
self.assertEqual(
{(in_node_0, cx_gate, 'qr[0]'), (cx_gate, out_node_0, 'qr[0]'),
(in_node_1, cx_gate, 'qr[1]'), (cx_gate, out_node_1, 'qr[1]')},
set(graph.weighted_edge_list()))

def test_insert_node_on_out_edges_double(self):
graph = retworkx.PyDiGraph()
in_node = graph.add_node('qr[0]')
out_node = graph.add_child(in_node, 'qr[0]', 'qr[0]')
h_gate = graph.add_node('h')
z_gate = graph.add_node('z')
graph.insert_node_on_out_edges(h_gate, in_node)
graph.insert_node_on_out_edges(z_gate, in_node)
self.assertEqual(
{(in_node, z_gate, 'qr[0]'), (z_gate, h_gate, 'qr[0]'),
(h_gate, out_node, 'qr[0]')},
set(graph.weighted_edge_list()))

def test_insert_node_on_out_edges_multiple_double(self):
graph = retworkx.PyDiGraph()
in_node_0 = graph.add_node('qr[0]')
out_node_0 = graph.add_child(in_node_0, 'qr[0]', 'qr[0]')
in_node_1 = graph.add_node('qr[1]')
out_node_1 = graph.add_child(in_node_1, 'qr[1]', 'qr[1]')
cx_gate = graph.add_node('cx')
cz_gate = graph.add_node('cz')
graph.insert_node_on_out_edges_multiple(cx_gate,
[in_node_0, in_node_1])
graph.insert_node_on_out_edges_multiple(cz_gate,
[in_node_0, in_node_1])
self.assertEqual(
{(in_node_0, cz_gate, 'qr[0]'), (cz_gate, cx_gate, 'qr[0]'),
(in_node_1, cz_gate, 'qr[1]'), (cz_gate, cx_gate, 'qr[1]'),
(cx_gate, out_node_0, 'qr[0]'), (cx_gate, out_node_1, 'qr[1]')},
set(graph.weighted_edge_list()))

def test_insert_node_on_in_edges_no_edges(self):
graph = retworkx.PyDiGraph()
node_a = graph.add_node(None)
node_b = graph.add_node(None)
graph.insert_node_on_in_edges(node_b, node_a)
self.assertEqual([], graph.edge_list())

def test_insert_node_on_in_edges_multiple_no_edges(self):
graph = retworkx.PyDiGraph()
node_a = graph.add_node(None)
node_b = graph.add_node(None)
graph.insert_node_on_in_edges_multiple(node_b, [node_a])
self.assertEqual([], graph.edge_list())

def test_insert_node_on_out_edges_no_edges(self):
graph = retworkx.PyDiGraph()
node_a = graph.add_node(None)
node_b = graph.add_node(None)
graph.insert_node_on_out_edges(node_b, node_a)
self.assertEqual([], graph.edge_list())

def test_insert_node_on_out_edges_multiple_no_edges(self):
graph = retworkx.PyDiGraph()
node_a = graph.add_node(None)
node_b = graph.add_node(None)
graph.insert_node_on_out_edges_multiple(node_b, [node_a])
self.assertEqual([], graph.edge_list())