Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 50 additions & 40 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,25 +811,22 @@ where
/// path between two nodes then the corresponding matrix entry will be
/// ``np.inf``.
/// :rtype: numpy.ndarray
#[pyfunction]
#[text_signature = "(graph, weight_fn, /)"]
#[pyfunction(default_weight = "1.0")]
#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"]
fn graph_floyd_warshall_numpy(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PyObject> {
let n = graph.node_count();
// Allocate empty matrix
let mut mat = Array2::<f64>::from_elem((n, n), std::f64::INFINITY);

let weight_callable = |a: &PyObject| -> PyResult<f64> {
let res = weight_fn.call1(py, (a,))?;
res.extract(py)
};

// Build adjacency matrix
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
mat[[j, i]] = mat[[j, i]].min(edge_weight);
}
Expand Down Expand Up @@ -879,27 +876,24 @@ fn graph_floyd_warshall_numpy(
/// path between two nodes then the corresponding matrix entry will be
/// ``np.inf``.
/// :rtype: numpy.ndarray
#[pyfunction(as_undirected = "false")]
#[text_signature = "(graph, weight_fn, /, as_undirected=False)"]
#[pyfunction(as_undirected = "false", default_weight = "1.0")]
#[text_signature = "(graph, /, weight_fn=None as_undirected=False, default_weight=1.0)"]
fn digraph_floyd_warshall_numpy(
py: Python,
graph: &digraph::PyDiGraph,
weight_fn: PyObject,
weight_fn: Option<PyObject>,
as_undirected: bool,
default_weight: f64,
) -> PyResult<PyObject> {
let n = graph.node_count();

// Allocate empty matrix
let mut mat = Array2::<f64>::from_elem((n, n), std::f64::INFINITY);

let weight_callable = |a: &PyObject| -> PyResult<f64> {
let res = weight_fn.call1(py, (a,))?;
res.extract(py)
};

// Build adjacency matrix
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
if as_undirected {
mat[[j, i]] = mat[[j, i]].min(edge_weight);
Expand Down Expand Up @@ -1016,14 +1010,29 @@ fn layers(
Ok(PyList::new(py, output).into())
}

fn weight_callable(
py: Python,
weight_fn: &Option<PyObject>,
weight: PyObject,
default: f64,
) -> PyResult<f64> {
match weight_fn {
Some(weight_fn) => {
let res = weight_fn.call1(py, (weight,))?;
res.extract(py)
}
None => Ok(default),
}
}

/// Return the adjacency matrix for a PyDiGraph object
///
/// In the case where there are multiple edges between nodes the value in the
/// output matrix will be the sum of the edges' weights.
///
/// :param PyDiGraph graph: The DiGraph used to generate the adjacency matrix
/// from
/// :param weight_fn callable: A callable object (function, lambda, etc) which
/// :param callable weight_fn: A callable object (function, lambda, etc) which
/// will be passed the edge object and expected to return a ``float``. This
/// tells retworkx/rust how to extract a numerical weight as a ``float``
/// for edge object. Some simple examples are::
Expand All @@ -1034,26 +1043,27 @@ fn layers(
///
/// dag_adjacency_matrix(dag, weight_fn: lambda x: float(x))
///
/// to cast the edge object as a float as the weight.
/// to cast the edge object as a float as the weight. If this is not
/// specified a default value (either ``default_weight`` or 1) will be used
/// for all edges.
/// :param float default_weight: If ``weight_fn`` is not used this can be
/// optionally used to specify a default weight to use for all edges.
///
/// :return: The adjacency matrix for the input dag as a numpy array
/// :rtype: numpy.ndarray
#[pyfunction]
#[text_signature = "(graph, weight_fn, /)"]
#[pyfunction(default_weight = "1.0")]
#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"]
fn digraph_adjacency_matrix(
py: Python,
graph: &digraph::PyDiGraph,
weight_fn: PyObject,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PyObject> {
let n = graph.node_count();
let mut matrix = Array2::<f64>::zeros((n, n));

let weight_callable = |a: &PyObject| -> PyResult<f64> {
let res = weight_fn.call1(py, (a,))?;
res.extract(py)
};
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
matrix[[i, j]] += edge_weight;
}
Ok(matrix.into_pyarray(py).into())
Expand All @@ -1076,30 +1086,30 @@ fn digraph_adjacency_matrix(
///
/// graph_adjacency_matrix(graph, weight_fn: lambda x: float(x))
///
/// to cast the edge object as a float as the weight.
/// to cast the edge object as a float as the weight. If this is not
/// specified a default value (either ``default_weight`` or 1) will be used
/// for all edges.
/// :param float default_weight: If ``weight_fn`` is not used this can be
/// optionally used to specify a default weight to use for all edges.
///
/// :return: The adjacency matrix for the input dag as a numpy array
/// :rtype: numpy.ndarray
#[pyfunction]
#[text_signature = "(graph, weight_fn, /)"]
#[pyfunction(default_weight = "1.0")]
#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"]
fn graph_adjacency_matrix(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PyObject> {
let n = graph.node_count();
let mut matrix = Array2::<f64>::zeros((n, n));

let weight_callable = |a: &PyObject| -> PyResult<f64> {
let res = weight_fn.call1(py, (a,))?;
res.extract(py)
};
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
matrix[[i, j]] += edge_weight;
matrix[[j, i]] += edge_weight;
}

Ok(matrix.into_pyarray(py).into())
}

Expand Down
56 changes: 56 additions & 0 deletions tests/test_adjacency_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,32 @@ def test_single_neighbor(self):
dtype=np.float64),
res))

def test_no_weight_fn(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
dag.add_child(node_a, 'b', {'a': 1})
dag.add_child(node_a, 'c', {'a': 2})
res = retworkx.digraph_adjacency_matrix(dag)
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(
np.array(
[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
dtype=np.float64),
res))

def test_default_weight(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
dag.add_child(node_a, 'b', {'a': 1})
dag.add_child(node_a, 'c', {'a': 2})
res = retworkx.digraph_adjacency_matrix(dag, default_weight=4)
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(
np.array(
[[0.0, 4.0, 4.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
dtype=np.float64),
res))

def test_float_cast_weight_func(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
Expand Down Expand Up @@ -88,6 +114,36 @@ def test_single_neighbor(self):
dtype=np.float64),
res))

def test_no_weight_fn(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, 'edge_a')
node_c = graph.add_node('c')
graph.add_edge(node_b, node_c, 'edge_b')
res = retworkx.graph_adjacency_matrix(graph)
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(
np.array(
[[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]],
dtype=np.float64),
res))

def test_default_weight(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, 'edge_a')
node_c = graph.add_node('c')
graph.add_edge(node_b, node_c, 'edge_b')
res = retworkx.graph_adjacency_matrix(graph, default_weight=4)
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(
np.array(
[[0.0, 4.0, 0.0], [4.0, 0.0, 4.0], [0.0, 4.0, 0.0]],
dtype=np.float64),
res))

def test_float_cast_weight_func(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
Expand Down
40 changes: 40 additions & 0 deletions tests/test_floyd_warshall.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,43 @@ def test_floyd_warshall_numpy_graph_cycle_with_removals(self):
dist = retworkx.graph_floyd_warshall_numpy(graph, lambda x: 1)
self.assertEqual(dist[0, 3], 3)
self.assertEqual(dist[0, 4], 3)

def test_floyd_warshall_numpy_digraph_cycle_no_weight_fn(self):
graph = retworkx.PyDiGraph()
graph.add_nodes_from(list(range(8)))
graph.remove_node(0)
graph.add_edges_from_no_data(
[(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])
dist = retworkx.digraph_floyd_warshall_numpy(graph)
self.assertEqual(dist[0, 3], 3)
self.assertEqual(dist[0, 4], 4)

def test_floyd_warshall_numpy_graph_cycle_no_weight_fn(self):
graph = retworkx.PyGraph()
graph.add_nodes_from(list(range(8)))
graph.remove_node(0)
graph.add_edges_from_no_data(
[(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])
dist = retworkx.graph_floyd_warshall_numpy(graph)
self.assertEqual(dist[0, 3], 3)
self.assertEqual(dist[0, 4], 3)

def test_floyd_warshall_numpy_digraph_cycle_default_weight(self):
graph = retworkx.PyDiGraph()
graph.add_nodes_from(list(range(8)))
graph.remove_node(0)
graph.add_edges_from_no_data(
[(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])
dist = retworkx.digraph_floyd_warshall_numpy(graph, default_weight=2)
self.assertEqual(dist[0, 3], 6)
self.assertEqual(dist[0, 4], 8)

def test_floyd_warshall_numpy_graph_cycle_default_weight(self):
graph = retworkx.PyGraph()
graph.add_nodes_from(list(range(8)))
graph.remove_node(0)
graph.add_edges_from_no_data(
[(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])
dist = retworkx.graph_floyd_warshall_numpy(graph, default_weight=2)
self.assertEqual(dist[0, 3], 6)
self.assertEqual(dist[0, 4], 6)