Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,17 +824,20 @@ fn graph_floyd_warshall_numpy(
/// graph_floyd_warshall_numpy(graph, weight_fn: lambda x: float(x))
///
/// to cast the edge object as a float as the weight.
/// :param as_undirected: If set to true each directed edge will be treated as
/// bidirectional/undirected.
///
/// :returns: A matrix of shortest path distances between nodes. If there is no
/// path between two nodes then the corresponding matrix entry will be
/// ``np.inf``.
/// :rtype: numpy.ndarray
#[pyfunction]
#[text_signature = "(graph, weight_fn, /)"]
#[pyfunction(as_undirected = "false")]
#[text_signature = "(graph, weight_fn, /, as_undirected=False)"]
fn digraph_floyd_warshall_numpy(
py: Python,
graph: &digraph::PyDiGraph,
weight_fn: PyObject,
as_undirected: bool,
) -> PyResult<PyObject> {
let n = graph.node_count();

Expand All @@ -850,6 +853,9 @@ fn digraph_floyd_warshall_numpy(
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
if as_undirected {
mat[[j, i]] = mat[[j, i]].min(edge_weight);
}
}
// 0 out the diagonal
for x in mat.diag_mut() {
Expand Down
16 changes: 16 additions & 0 deletions tests/test_floyd_warshall.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,22 @@ def test_floyd_warshall_numpy_cycle(self):
self.assertEqual(dist[0, 3], 3)
self.assertEqual(dist[0, 4], 3)

def test_directed_floyd_warshall_numpy_cycle_as_undirected(self):
graph = retworkx.PyDiGraph()
graph.add_nodes_from(list(range(7)))
graph.add_edges_from_no_data(
[(0, 1), (0, 6), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)])
dist = retworkx.digraph_floyd_warshall_numpy(graph, lambda x: 1,
as_undirected=True)
expected = numpy.array([[0., 1., 2., 3., 3., 2., 1.],
[1., 0., 1., 2., 3., 3., 2.],
[2., 1., 0., 1., 2., 3., 3.],
[3., 2., 1., 0., 1., 2., 3.],
[3., 3., 2., 1., 0., 1., 2.],
[2., 3., 3., 2., 1., 0., 1.],
[1., 2., 3., 3., 2., 1., 0.]])
self.assertTrue(numpy.array_equal(dist, expected))

def test_floyd_warshall_numpy_digraph_three_edges(self):
graph = retworkx.PyDiGraph()
graph.add_nodes_from(list(range(6)))
Expand Down