diff --git a/docs/source/api.rst b/docs/source/api.rst index 5742633035..0ac5fcdb73 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -66,6 +66,7 @@ Algorithm Functions retworkx.floyd_warshall retworkx.graph_floyd_warshall_numpy retworkx.digraph_floyd_warshall_numpy + retworkx.collect_runs retworkx.layers retworkx.digraph_adjacency_matrix retworkx.graph_adjacency_matrix diff --git a/src/lib.rs b/src/lib.rs index 3579568d84..5304bdfd7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1084,6 +1084,76 @@ fn digraph_floyd_warshall_numpy( Ok(mat.into_pyarray(py).into()) } +/// Collect runs that match a filter function +/// +/// A run is a path of nodes where there is only a single successor and all +/// nodes in the path match the given condition. Each node in the graph can +/// appear in only a single run. +/// +/// :param PyDiGraph graph: The graph to find runs in +/// :param filter_fn: The filter function to use for matching nodes. It takes +/// in one argument, the node data payload/weight object, and will return a +/// boolean whether the node matches the conditions or not. If it returns +/// ``False`` it will skip that node. +/// +/// :returns: a list of runs, where each run is a list of node data +/// payload/weight for the nodes in the run +/// :rtype: list +#[pyfunction] +#[text_signature = "(graph, filter)"] +fn collect_runs( + py: Python, + graph: &digraph::PyDiGraph, + filter_fn: PyObject, +) -> PyResult>> { + let mut out_list: Vec> = Vec::new(); + let mut seen: HashSet = HashSet::new(); + + let filter_node = |node: &PyObject| -> PyResult { + let res = filter_fn.call1(py, (node,))?; + Ok(res.extract(py)?) + }; + + let nodes = match algo::toposort(graph, None) { + Ok(nodes) => nodes, + Err(_err) => { + return Err(DAGHasCycle::new_err("Sort encountered a cycle")) + } + }; + for node in nodes { + if !filter_node(&graph.graph[node])? || seen.contains(&node) { + continue; + } + seen.insert(node); + let mut group: Vec = vec![graph.graph[node].clone_ref(py)]; + let mut successors: Vec = graph + .graph + .neighbors_directed(node, petgraph::Direction::Outgoing) + .collect(); + successors.dedup(); + + while successors.len() == 1 + && filter_node(&graph.graph[successors[0]])? + && !seen.contains(&successors[0]) + { + group.push(graph.graph[successors[0]].clone_ref(py)); + seen.insert(successors[0]); + successors = graph + .graph + .neighbors_directed( + successors[0], + petgraph::Direction::Outgoing, + ) + .collect(); + successors.dedup(); + } + if !group.is_empty() { + out_list.push(group); + } + } + Ok(out_list) +} + /// Return a list of layers /// /// A layer is a subgraph whose nodes are disjoint, i.e., @@ -2547,6 +2617,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(floyd_warshall))?; m.add_wrapped(wrap_pyfunction!(graph_floyd_warshall_numpy))?; m.add_wrapped(wrap_pyfunction!(digraph_floyd_warshall_numpy))?; + m.add_wrapped(wrap_pyfunction!(collect_runs))?; m.add_wrapped(wrap_pyfunction!(layers))?; m.add_wrapped(wrap_pyfunction!(graph_distance_matrix))?; m.add_wrapped(wrap_pyfunction!(digraph_distance_matrix))?; diff --git a/tests/test_collect_runs.py b/tests/test_collect_runs.py new file mode 100644 index 0000000000..6f2585f992 --- /dev/null +++ b/tests/test_collect_runs.py @@ -0,0 +1,138 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import retworkx + + +class TestCollectRuns(unittest.TestCase): + def test_dagcircuit_basic(self): + dag = retworkx.PyDAG() + qr_0_in = dag.add_node("qr[0]") + qr_0_out = dag.add_node("qr[0]") + qr_1_in = dag.add_node("qr[1]") + qr_1_out = dag.add_node("qr[1]") + cr_0_in = dag.add_node("cr[0]") + cr_0_out = dag.add_node("cr[0]") + cr_1_in = dag.add_node("cr[1]") + cr_1_out = dag.add_node("cr[1]") + + h_gate = dag.add_child(qr_0_in, "h", "qr[0]") + x_gate = dag.add_child(h_gate, "x", "qr[0]") + cx_gate = dag.add_child(x_gate, "cx", "qr[0]") + dag.add_edge(qr_1_in, cx_gate, "qr[1]") + + measure_qr_1 = dag.add_child(cx_gate, "measure", "qr[1]") + dag.add_edge(cr_1_in, measure_qr_1, "cr[1]") + x_gate = dag.add_child(measure_qr_1, "x", "qr[1]") + dag.add_edge(measure_qr_1, x_gate, "cr[1]") + dag.add_edge(cr_0_in, x_gate, "cr[0]") + + measure_qr_0 = dag.add_child(cx_gate, "measure", "qr[0]") + dag.add_edge(measure_qr_0, qr_0_out, "qr[0]") + dag.add_edge(measure_qr_0, cr_0_out, "cr[0]") + dag.add_edge(x_gate, measure_qr_0, "cr[0]") + + measure_qr_1_out = dag.add_child(x_gate, "measure", "cr[1]") + dag.add_edge(x_gate, measure_qr_1_out, "qr[1]") + dag.add_edge(measure_qr_1_out, qr_1_out, "qr[1]") + dag.add_edge(measure_qr_1_out, cr_1_out, "cr[1]") + + def filter_function(node): + return node in ['h', 'x'] + + res = retworkx.collect_runs(dag, filter_function) + expected = [['h', 'x'], ['x']] + self.assertEqual(expected, res) + + def test_multiple_successor_edges(self): + dag = retworkx.PyDiGraph() + q0, q1 = dag.add_nodes_from(['q0', 'q1']) + cx_1 = dag.add_child(q0, 'cx', 'q0') + dag.add_edge(q1, cx_1, 'q1') + cx_2 = dag.add_child(cx_1, 'cx', 'q0') + dag.add_edge(q1, cx_2, 'q1') + cx_3 = dag.add_child(cx_2, 'cx', 'q0') + dag.add_edge(q1, cx_3, 'q1') + + def filter_function(node): + return node == 'cx' + + res = retworkx.collect_runs(dag, filter_function) + self.assertEqual([['cx', 'cx', 'cx']], res) + + def test_cycle(self): + dag = retworkx.PyDiGraph() + dag.extend_from_edge_list([(0, 1), (1, 2), (2, 0)]) + with self.assertRaises(retworkx.DAGHasCycle): + retworkx.collect_runs(dag, lambda _: True) + + def test_filter_function_inner_exception(self): + dag = retworkx.PyDiGraph() + dag.add_node('a') + dag.add_child(0, 'b', None) + + def filter_function(node): + raise IndexError("Things fail from time to time") + + with self.assertRaises(IndexError): + retworkx.collect_runs(dag, filter_function) + + def test_empty(self): + dag = retworkx.PyDAG() + self.assertEqual([], retworkx.collect_runs(dag, lambda _: True)) + + def test_h_h_cx(self): + dag = retworkx.PyDiGraph() + q0, q1 = dag.add_nodes_from(['q0', 'q1']) + h_1 = dag.add_child(q0, 'h', 'q0') + h_2 = dag.add_child(q1, 'h', 'q1') + cx_2 = dag.add_child(h_1, 'cx', 'q0') + dag.add_edge(h_2, cx_2, 'q1') + + def filter_function(node): + return node in ['cx', 'h'] + + res = retworkx.collect_runs(dag, filter_function) + self.assertEqual([['h', 'cx'], ['h']], res) + + def test_cx_h_h_cx(self): + dag = retworkx.PyDiGraph() + q0, q1 = dag.add_nodes_from(['q0', 'q1']) + cx_1 = dag.add_child(q0, 'cx', 'q0') + dag.add_edge(q1, cx_1, 'q1') + h_1 = dag.add_child(cx_1, 'h', 'q0') + h_2 = dag.add_child(cx_1, 'h', 'q1') + cx_2 = dag.add_child(h_1, 'cx', 'q0') + dag.add_edge(h_2, cx_2, 'q1') + + def filter_function(node): + return node in ['cx', 'h'] + + res = retworkx.collect_runs(dag, filter_function) + self.assertEqual([['cx'], ['h', 'cx'], ['h']], res) + + def test_cx_h_cx(self): + dag = retworkx.PyDiGraph() + q0, q1 = dag.add_nodes_from(['q0', 'q1']) + cx_1 = dag.add_child(q0, 'cx', 'q0') + dag.add_edge(q1, cx_1, 'q1') + h_1 = dag.add_child(cx_1, 'h', 'q0') + cx_2 = dag.add_child(h_1, 'cx', 'q0') + dag.add_edge(cx_1, cx_2, 'q1') + + def filter_function(node): + return node in ['cx', 'h'] + + res = retworkx.collect_runs(dag, filter_function) + self.assertEqual([['cx'], ['h', 'cx']], res)