Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Algorithm Functions
retworkx.floyd_warshall
retworkx.graph_floyd_warshall_numpy
retworkx.digraph_floyd_warshall_numpy
retworkx.collect_runs
retworkx.layers
retworkx.digraph_adjacency_matrix
retworkx.graph_adjacency_matrix
Expand Down
71 changes: 71 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,76 @@ fn digraph_floyd_warshall_numpy(
Ok(mat.into_pyarray(py).into())
}

/// Collect runs that match a filter function
///
/// A run is a path of nodes where there is only a single successor and all
/// nodes in the path match the given condition. Each node in the graph can
/// appear in only a single run.
///
/// :param PyDiGraph graph: The graph to find runs in
/// :param filter_fn: The filter function to use for matching nodes. It takes
/// in one argument, the node data payload/weight object, and will return a
/// boolean whether the node matches the conditions or not. If it returns
/// ``False`` it will skip that node.
///
/// :returns: a list of runs, where each run is a list of node data
/// payload/weight for the nodes in the run
/// :rtype: list
#[pyfunction]
#[text_signature = "(graph, filter)"]
fn collect_runs(
py: Python,
graph: &digraph::PyDiGraph,
filter_fn: PyObject,
) -> PyResult<Vec<Vec<PyObject>>> {
let mut out_list: Vec<Vec<PyObject>> = Vec::new();
let mut seen: HashSet<NodeIndex> = HashSet::new();

let filter_node = |node: &PyObject| -> PyResult<bool> {
let res = filter_fn.call1(py, (node,))?;
Ok(res.extract(py)?)
};

let nodes = match algo::toposort(graph, None) {
Ok(nodes) => nodes,
Err(_err) => {
return Err(DAGHasCycle::new_err("Sort encountered a cycle"))
}
};
for node in nodes {
if !filter_node(&graph.graph[node])? || seen.contains(&node) {
continue;
}
seen.insert(node);
let mut group: Vec<PyObject> = vec![graph.graph[node].clone_ref(py)];
let mut successors: Vec<NodeIndex> = graph
.graph
.neighbors_directed(node, petgraph::Direction::Outgoing)
.collect();
successors.dedup();

while successors.len() == 1
&& filter_node(&graph.graph[successors[0]])?
&& !seen.contains(&successors[0])
{
group.push(graph.graph[successors[0]].clone_ref(py));
seen.insert(successors[0]);
successors = graph
.graph
.neighbors_directed(
successors[0],
petgraph::Direction::Outgoing,
)
.collect();
successors.dedup();
}
if !group.is_empty() {
out_list.push(group);
}
}
Ok(out_list)
}

/// Return a list of layers
///
/// A layer is a subgraph whose nodes are disjoint, i.e.,
Expand Down Expand Up @@ -2547,6 +2617,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(floyd_warshall))?;
m.add_wrapped(wrap_pyfunction!(graph_floyd_warshall_numpy))?;
m.add_wrapped(wrap_pyfunction!(digraph_floyd_warshall_numpy))?;
m.add_wrapped(wrap_pyfunction!(collect_runs))?;
m.add_wrapped(wrap_pyfunction!(layers))?;
m.add_wrapped(wrap_pyfunction!(graph_distance_matrix))?;
m.add_wrapped(wrap_pyfunction!(digraph_distance_matrix))?;
Expand Down
138 changes: 138 additions & 0 deletions tests/test_collect_runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import retworkx


class TestCollectRuns(unittest.TestCase):
def test_dagcircuit_basic(self):
dag = retworkx.PyDAG()
qr_0_in = dag.add_node("qr[0]")
qr_0_out = dag.add_node("qr[0]")
qr_1_in = dag.add_node("qr[1]")
qr_1_out = dag.add_node("qr[1]")
cr_0_in = dag.add_node("cr[0]")
cr_0_out = dag.add_node("cr[0]")
cr_1_in = dag.add_node("cr[1]")
cr_1_out = dag.add_node("cr[1]")

h_gate = dag.add_child(qr_0_in, "h", "qr[0]")
x_gate = dag.add_child(h_gate, "x", "qr[0]")
cx_gate = dag.add_child(x_gate, "cx", "qr[0]")
dag.add_edge(qr_1_in, cx_gate, "qr[1]")

measure_qr_1 = dag.add_child(cx_gate, "measure", "qr[1]")
Comment thread
mtreinish marked this conversation as resolved.
dag.add_edge(cr_1_in, measure_qr_1, "cr[1]")
x_gate = dag.add_child(measure_qr_1, "x", "qr[1]")
dag.add_edge(measure_qr_1, x_gate, "cr[1]")
dag.add_edge(cr_0_in, x_gate, "cr[0]")

measure_qr_0 = dag.add_child(cx_gate, "measure", "qr[0]")
dag.add_edge(measure_qr_0, qr_0_out, "qr[0]")
dag.add_edge(measure_qr_0, cr_0_out, "cr[0]")
dag.add_edge(x_gate, measure_qr_0, "cr[0]")

measure_qr_1_out = dag.add_child(x_gate, "measure", "cr[1]")
dag.add_edge(x_gate, measure_qr_1_out, "qr[1]")
dag.add_edge(measure_qr_1_out, qr_1_out, "qr[1]")
dag.add_edge(measure_qr_1_out, cr_1_out, "cr[1]")

def filter_function(node):
return node in ['h', 'x']

res = retworkx.collect_runs(dag, filter_function)
expected = [['h', 'x'], ['x']]
self.assertEqual(expected, res)

def test_multiple_successor_edges(self):
dag = retworkx.PyDiGraph()
q0, q1 = dag.add_nodes_from(['q0', 'q1'])
cx_1 = dag.add_child(q0, 'cx', 'q0')
dag.add_edge(q1, cx_1, 'q1')
cx_2 = dag.add_child(cx_1, 'cx', 'q0')
dag.add_edge(q1, cx_2, 'q1')
cx_3 = dag.add_child(cx_2, 'cx', 'q0')
dag.add_edge(q1, cx_3, 'q1')

def filter_function(node):
return node == 'cx'

res = retworkx.collect_runs(dag, filter_function)
self.assertEqual([['cx', 'cx', 'cx']], res)

def test_cycle(self):
dag = retworkx.PyDiGraph()
dag.extend_from_edge_list([(0, 1), (1, 2), (2, 0)])
with self.assertRaises(retworkx.DAGHasCycle):
retworkx.collect_runs(dag, lambda _: True)

def test_filter_function_inner_exception(self):
dag = retworkx.PyDiGraph()
dag.add_node('a')
dag.add_child(0, 'b', None)

def filter_function(node):
raise IndexError("Things fail from time to time")

with self.assertRaises(IndexError):
retworkx.collect_runs(dag, filter_function)

def test_empty(self):
dag = retworkx.PyDAG()
self.assertEqual([], retworkx.collect_runs(dag, lambda _: True))

def test_h_h_cx(self):
dag = retworkx.PyDiGraph()
q0, q1 = dag.add_nodes_from(['q0', 'q1'])
h_1 = dag.add_child(q0, 'h', 'q0')
h_2 = dag.add_child(q1, 'h', 'q1')
cx_2 = dag.add_child(h_1, 'cx', 'q0')
dag.add_edge(h_2, cx_2, 'q1')

def filter_function(node):
return node in ['cx', 'h']

res = retworkx.collect_runs(dag, filter_function)
self.assertEqual([['h', 'cx'], ['h']], res)

def test_cx_h_h_cx(self):
dag = retworkx.PyDiGraph()
q0, q1 = dag.add_nodes_from(['q0', 'q1'])
cx_1 = dag.add_child(q0, 'cx', 'q0')
dag.add_edge(q1, cx_1, 'q1')
h_1 = dag.add_child(cx_1, 'h', 'q0')
h_2 = dag.add_child(cx_1, 'h', 'q1')
cx_2 = dag.add_child(h_1, 'cx', 'q0')
dag.add_edge(h_2, cx_2, 'q1')

def filter_function(node):
return node in ['cx', 'h']

res = retworkx.collect_runs(dag, filter_function)
self.assertEqual([['cx'], ['h', 'cx'], ['h']], res)

def test_cx_h_cx(self):
dag = retworkx.PyDiGraph()
q0, q1 = dag.add_nodes_from(['q0', 'q1'])
cx_1 = dag.add_child(q0, 'cx', 'q0')
dag.add_edge(q1, cx_1, 'q1')
h_1 = dag.add_child(cx_1, 'h', 'q0')
cx_2 = dag.add_child(h_1, 'cx', 'q0')
dag.add_edge(cx_1, cx_2, 'q1')

def filter_function(node):
return node in ['cx', 'h']

res = retworkx.collect_runs(dag, filter_function)
self.assertEqual([['cx'], ['h', 'cx']], res)