Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Algorithm Functions
retworkx.floyd_warshall
retworkx.graph_floyd_warshall_numpy
retworkx.digraph_floyd_warshall_numpy
retworkx.collect_runs
retworkx.layers
retworkx.digraph_adjacency_matrix
retworkx.graph_adjacency_matrix
Expand Down
71 changes: 71 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,76 @@ fn digraph_floyd_warshall_numpy(
Ok(mat.into_pyarray(py).into())
}

/// Collect runs that match a filter function
///
/// A run is a path of nodes where there is only a single successor and all
/// nodes in the path match the given condition. Each node in the graph can
/// appear in only a single run.
///
/// :param PyDiGraph graph: The graph to find runs in
/// :param filter_fn: The filter function to use for matching nodes. It takes
/// in one argument, the node data payload/weight object, and will return a
/// boolean whether the node matches the conditions or not. If it returns
/// ``False`` it will skip that node.
///
/// :returns: a list of runs, where each run is a list of node data
/// payload/weight for the nodes in the run
/// :rtype: list
#[pyfunction]
#[text_signature = "(graph, filter)"]
fn collect_runs(
py: Python,
graph: &digraph::PyDiGraph,
filter_fn: PyObject,
) -> PyResult<Vec<Vec<PyObject>>> {
let mut out_list: Vec<Vec<PyObject>> = Vec::new();
let mut seen: HashSet<NodeIndex> = HashSet::new();

let filter_node = |node: &PyObject| -> PyResult<bool> {
let res = filter_fn.call1(py, (node,))?;
Ok(res.extract(py)?)
};

let nodes = match algo::toposort(graph, None) {
Ok(nodes) => nodes,
Err(_err) => {
return Err(DAGHasCycle::new_err("Sort encountered a cycle"))
}
};
for node in nodes {
if !filter_node(&graph.graph[node])? || seen.contains(&node) {
continue;
}
seen.insert(node);
let mut group: Vec<PyObject> = vec![graph.graph[node].clone_ref(py)];
let mut successors: Vec<NodeIndex> = graph
.graph
.neighbors_directed(node, petgraph::Direction::Outgoing)
.collect();
successors.dedup();

while successors.len() == 1
&& filter_node(&graph.graph[successors[0]])?
&& !seen.contains(&successors[0])
{
group.push(graph.graph[successors[0]].clone_ref(py));
seen.insert(successors[0]);
successors = graph
.graph
.neighbors_directed(
successors[0],
petgraph::Direction::Outgoing,
)
.collect();
successors.dedup();
}
if !group.is_empty() {
out_list.push(group);
}
}
Ok(out_list)
}

/// Return a list of layers
///
/// A layer is a subgraph whose nodes are disjoint, i.e.,
Expand Down Expand Up @@ -2545,6 +2615,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(floyd_warshall))?;
m.add_wrapped(wrap_pyfunction!(graph_floyd_warshall_numpy))?;
m.add_wrapped(wrap_pyfunction!(digraph_floyd_warshall_numpy))?;
m.add_wrapped(wrap_pyfunction!(collect_runs))?;
m.add_wrapped(wrap_pyfunction!(layers))?;
m.add_wrapped(wrap_pyfunction!(graph_distance_matrix))?;
m.add_wrapped(wrap_pyfunction!(digraph_distance_matrix))?;
Expand Down
138 changes: 138 additions & 0 deletions tests/test_collect_runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import retworkx


class TestCollectRuns(unittest.TestCase):
def test_dagcircuit_basic(self):
dag = retworkx.PyDAG()
qr_0_in = dag.add_node("qr[0]")
qr_0_out = dag.add_node("qr[0]")
qr_1_in = dag.add_node("qr[1]")
qr_1_out = dag.add_node("qr[1]")
cr_0_in = dag.add_node("cr[0]")
cr_0_out = dag.add_node("cr[0]")
cr_1_in = dag.add_node("cr[1]")
cr_1_out = dag.add_node("cr[1]")

h_gate = dag.add_child(qr_0_in, "h", "qr[0]")
x_gate = dag.add_child(h_gate, "x", "qr[0]")
cx_gate = dag.add_child(x_gate, "cx", "qr[0]")
dag.add_edge(qr_1_in, cx_gate, "qr[1]")

measure_qr_1 = dag.add_child(cx_gate, "measure", "qr[1]")
Comment thread
mtreinish marked this conversation as resolved.
dag.add_edge(cr_1_in, measure_qr_1, "cr[1]")
x_gate = dag.add_child(measure_qr_1, "x", "qr[1]")
dag.add_edge(measure_qr_1, x_gate, "cr[1]")
dag.add_edge(cr_0_in, x_gate, "cr[0]")

measure_qr_0 = dag.add_child(cx_gate, "measure", "qr[0]")
dag.add_edge(measure_qr_0, qr_0_out, "qr[0]")
dag.add_edge(measure_qr_0, cr_0_out, "cr[0]")
dag.add_edge(x_gate, measure_qr_0, "cr[0]")

measure_qr_1_out = dag.add_child(x_gate, "measure", "cr[1]")
dag.add_edge(x_gate, measure_qr_1_out, "qr[1]")
dag.add_edge(measure_qr_1_out, qr_1_out, "qr[1]")
dag.add_edge(measure_qr_1_out, cr_1_out, "cr[1]")

def filter_function(node):
return node in ['h', 'x']

res = retworkx.collect_runs(dag, filter_function)
expected = [['h', 'x'], ['x']]
self.assertEqual(expected, res)

def test_multiple_successor_edges(self):
dag = retworkx.PyDiGraph()
q0, q1 = dag.add_nodes_from(['q0', 'q1'])
cx_1 = dag.add_child(q0, 'cx', 'q0')
dag.add_edge(q1, cx_1, 'q1')
cx_2 = dag.add_child(cx_1, 'cx', 'q0')
dag.add_edge(q1, cx_2, 'q1')
cx_3 = dag.add_child(cx_2, 'cx', 'q0')
dag.add_edge(q1, cx_3, 'q1')

def filter_function(node):
return node == 'cx'

res = retworkx.collect_runs(dag, filter_function)
self.assertEqual([['cx', 'cx', 'cx']], res)

def test_cycle(self):
dag = retworkx.PyDiGraph()
dag.extend_from_edge_list([(0, 1), (1, 2), (2, 0)])
with self.assertRaises(retworkx.DAGHasCycle):
retworkx.collect_runs(dag, lambda _: True)

def test_filter_function_inner_exception(self):
dag = retworkx.PyDiGraph()
dag.add_node('a')
dag.add_child(0, 'b', None)

def filter_function(node):
raise IndexError("Things fail from time to time")

with self.assertRaises(IndexError):
retworkx.collect_runs(dag, filter_function)

def test_empty(self):
dag = retworkx.PyDAG()
self.assertEqual([], retworkx.collect_runs(dag, lambda _: True))

def test_h_h_cx(self):
dag = retworkx.PyDiGraph()
q0, q1 = dag.add_nodes_from(['q0', 'q1'])
h_1 = dag.add_child(q0, 'h', 'q0')
h_2 = dag.add_child(q1, 'h', 'q1')
cx_2 = dag.add_child(h_1, 'cx', 'q0')
dag.add_edge(h_2, cx_2, 'q1')

def filter_function(node):
return node in ['cx', 'h']

res = retworkx.collect_runs(dag, filter_function)
self.assertEqual([['h', 'cx'], ['h']], res)

def test_cx_h_h_cx(self):
dag = retworkx.PyDiGraph()
q0, q1 = dag.add_nodes_from(['q0', 'q1'])
cx_1 = dag.add_child(q0, 'cx', 'q0')
dag.add_edge(q1, cx_1, 'q1')
h_1 = dag.add_child(cx_1, 'h', 'q0')
h_2 = dag.add_child(cx_1, 'h', 'q1')
cx_2 = dag.add_child(h_1, 'cx', 'q0')
dag.add_edge(h_2, cx_2, 'q1')

def filter_function(node):
return node in ['cx', 'h']

res = retworkx.collect_runs(dag, filter_function)
self.assertEqual([['cx'], ['h', 'cx'], ['h']], res)

def test_cx_h_cx(self):
dag = retworkx.PyDiGraph()
q0, q1 = dag.add_nodes_from(['q0', 'q1'])
cx_1 = dag.add_child(q0, 'cx', 'q0')
dag.add_edge(q1, cx_1, 'q1')
h_1 = dag.add_child(cx_1, 'h', 'q0')
cx_2 = dag.add_child(h_1, 'cx', 'q0')
dag.add_edge(cx_1, cx_2, 'q1')

def filter_function(node):
return node in ['cx', 'h']

res = retworkx.collect_runs(dag, filter_function)
self.assertEqual([['cx'], ['h', 'cx']], res)