diff --git a/src/lib.rs b/src/lib.rs index 33c17ac45a..b29545ef75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -811,25 +811,22 @@ where /// path between two nodes then the corresponding matrix entry will be /// ``np.inf``. /// :rtype: numpy.ndarray -#[pyfunction] -#[text_signature = "(graph, weight_fn, /)"] +#[pyfunction(default_weight = "1.0")] +#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"] fn graph_floyd_warshall_numpy( py: Python, graph: &graph::PyGraph, - weight_fn: PyObject, + weight_fn: Option, + default_weight: f64, ) -> PyResult { let n = graph.node_count(); // Allocate empty matrix let mut mat = Array2::::from_elem((n, n), std::f64::INFINITY); - let weight_callable = |a: &PyObject| -> PyResult { - let res = weight_fn.call1(py, (a,))?; - res.extract(py) - }; - // Build adjacency matrix for (i, j, weight) in get_edge_iter_with_weights(graph) { - let edge_weight = weight_callable(&weight)?; + let edge_weight = + weight_callable(py, &weight_fn, weight, default_weight)?; mat[[i, j]] = mat[[i, j]].min(edge_weight); mat[[j, i]] = mat[[j, i]].min(edge_weight); } @@ -879,27 +876,24 @@ fn graph_floyd_warshall_numpy( /// path between two nodes then the corresponding matrix entry will be /// ``np.inf``. /// :rtype: numpy.ndarray -#[pyfunction(as_undirected = "false")] -#[text_signature = "(graph, weight_fn, /, as_undirected=False)"] +#[pyfunction(as_undirected = "false", default_weight = "1.0")] +#[text_signature = "(graph, /, weight_fn=None as_undirected=False, default_weight=1.0)"] fn digraph_floyd_warshall_numpy( py: Python, graph: &digraph::PyDiGraph, - weight_fn: PyObject, + weight_fn: Option, as_undirected: bool, + default_weight: f64, ) -> PyResult { let n = graph.node_count(); // Allocate empty matrix let mut mat = Array2::::from_elem((n, n), std::f64::INFINITY); - let weight_callable = |a: &PyObject| -> PyResult { - let res = weight_fn.call1(py, (a,))?; - res.extract(py) - }; - // Build adjacency matrix for (i, j, weight) in get_edge_iter_with_weights(graph) { - let edge_weight = weight_callable(&weight)?; + let edge_weight = + weight_callable(py, &weight_fn, weight, default_weight)?; mat[[i, j]] = mat[[i, j]].min(edge_weight); if as_undirected { mat[[j, i]] = mat[[j, i]].min(edge_weight); @@ -1016,6 +1010,21 @@ fn layers( Ok(PyList::new(py, output).into()) } +fn weight_callable( + py: Python, + weight_fn: &Option, + weight: PyObject, + default: f64, +) -> PyResult { + match weight_fn { + Some(weight_fn) => { + let res = weight_fn.call1(py, (weight,))?; + res.extract(py) + } + None => Ok(default), + } +} + /// Return the adjacency matrix for a PyDiGraph object /// /// In the case where there are multiple edges between nodes the value in the @@ -1023,7 +1032,7 @@ fn layers( /// /// :param PyDiGraph graph: The DiGraph used to generate the adjacency matrix /// from -/// :param weight_fn callable: A callable object (function, lambda, etc) which +/// :param callable weight_fn: A callable object (function, lambda, etc) which /// will be passed the edge object and expected to return a ``float``. This /// tells retworkx/rust how to extract a numerical weight as a ``float`` /// for edge object. Some simple examples are:: @@ -1034,26 +1043,27 @@ fn layers( /// /// dag_adjacency_matrix(dag, weight_fn: lambda x: float(x)) /// -/// to cast the edge object as a float as the weight. +/// to cast the edge object as a float as the weight. If this is not +/// specified a default value (either ``default_weight`` or 1) will be used +/// for all edges. +/// :param float default_weight: If ``weight_fn`` is not used this can be +/// optionally used to specify a default weight to use for all edges. /// /// :return: The adjacency matrix for the input dag as a numpy array /// :rtype: numpy.ndarray -#[pyfunction] -#[text_signature = "(graph, weight_fn, /)"] +#[pyfunction(default_weight = "1.0")] +#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"] fn digraph_adjacency_matrix( py: Python, graph: &digraph::PyDiGraph, - weight_fn: PyObject, + weight_fn: Option, + default_weight: f64, ) -> PyResult { let n = graph.node_count(); let mut matrix = Array2::::zeros((n, n)); - - let weight_callable = |a: &PyObject| -> PyResult { - let res = weight_fn.call1(py, (a,))?; - res.extract(py) - }; for (i, j, weight) in get_edge_iter_with_weights(graph) { - let edge_weight = weight_callable(&weight)?; + let edge_weight = + weight_callable(py, &weight_fn, weight, default_weight)?; matrix[[i, j]] += edge_weight; } Ok(matrix.into_pyarray(py).into()) @@ -1076,30 +1086,30 @@ fn digraph_adjacency_matrix( /// /// graph_adjacency_matrix(graph, weight_fn: lambda x: float(x)) /// -/// to cast the edge object as a float as the weight. +/// to cast the edge object as a float as the weight. If this is not +/// specified a default value (either ``default_weight`` or 1) will be used +/// for all edges. +/// :param float default_weight: If ``weight_fn`` is not used this can be +/// optionally used to specify a default weight to use for all edges. /// /// :return: The adjacency matrix for the input dag as a numpy array /// :rtype: numpy.ndarray -#[pyfunction] -#[text_signature = "(graph, weight_fn, /)"] +#[pyfunction(default_weight = "1.0")] +#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"] fn graph_adjacency_matrix( py: Python, graph: &graph::PyGraph, - weight_fn: PyObject, + weight_fn: Option, + default_weight: f64, ) -> PyResult { let n = graph.node_count(); let mut matrix = Array2::::zeros((n, n)); - - let weight_callable = |a: &PyObject| -> PyResult { - let res = weight_fn.call1(py, (a,))?; - res.extract(py) - }; for (i, j, weight) in get_edge_iter_with_weights(graph) { - let edge_weight = weight_callable(&weight)?; + let edge_weight = + weight_callable(py, &weight_fn, weight, default_weight)?; matrix[[i, j]] += edge_weight; matrix[[j, i]] += edge_weight; } - Ok(matrix.into_pyarray(py).into()) } diff --git a/tests/test_adjacency_matrix.py b/tests/test_adjacency_matrix.py index e9c927fcb0..ea6f5374b9 100644 --- a/tests/test_adjacency_matrix.py +++ b/tests/test_adjacency_matrix.py @@ -30,6 +30,32 @@ def test_single_neighbor(self): dtype=np.float64), res)) + def test_no_weight_fn(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + dag.add_child(node_a, 'b', {'a': 1}) + dag.add_child(node_a, 'c', {'a': 2}) + res = retworkx.digraph_adjacency_matrix(dag) + self.assertIsInstance(res, np.ndarray) + self.assertTrue(np.array_equal( + np.array( + [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + dtype=np.float64), + res)) + + def test_default_weight(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + dag.add_child(node_a, 'b', {'a': 1}) + dag.add_child(node_a, 'c', {'a': 2}) + res = retworkx.digraph_adjacency_matrix(dag, default_weight=4) + self.assertIsInstance(res, np.ndarray) + self.assertTrue(np.array_equal( + np.array( + [[0.0, 4.0, 4.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + dtype=np.float64), + res)) + def test_float_cast_weight_func(self): dag = retworkx.PyDAG() node_a = dag.add_node('a') @@ -88,6 +114,36 @@ def test_single_neighbor(self): dtype=np.float64), res)) + def test_no_weight_fn(self): + graph = retworkx.PyGraph() + node_a = graph.add_node('a') + node_b = graph.add_node('b') + graph.add_edge(node_a, node_b, 'edge_a') + node_c = graph.add_node('c') + graph.add_edge(node_b, node_c, 'edge_b') + res = retworkx.graph_adjacency_matrix(graph) + self.assertIsInstance(res, np.ndarray) + self.assertTrue(np.array_equal( + np.array( + [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]], + dtype=np.float64), + res)) + + def test_default_weight(self): + graph = retworkx.PyGraph() + node_a = graph.add_node('a') + node_b = graph.add_node('b') + graph.add_edge(node_a, node_b, 'edge_a') + node_c = graph.add_node('c') + graph.add_edge(node_b, node_c, 'edge_b') + res = retworkx.graph_adjacency_matrix(graph, default_weight=4) + self.assertIsInstance(res, np.ndarray) + self.assertTrue(np.array_equal( + np.array( + [[0.0, 4.0, 0.0], [4.0, 0.0, 4.0], [0.0, 4.0, 0.0]], + dtype=np.float64), + res)) + def test_float_cast_weight_func(self): graph = retworkx.PyGraph() node_a = graph.add_node('a') diff --git a/tests/test_floyd_warshall.py b/tests/test_floyd_warshall.py index b33b71f9bf..1ae105d186 100644 --- a/tests/test_floyd_warshall.py +++ b/tests/test_floyd_warshall.py @@ -231,3 +231,43 @@ def test_floyd_warshall_numpy_graph_cycle_with_removals(self): dist = retworkx.graph_floyd_warshall_numpy(graph, lambda x: 1) self.assertEqual(dist[0, 3], 3) self.assertEqual(dist[0, 4], 3) + + def test_floyd_warshall_numpy_digraph_cycle_no_weight_fn(self): + graph = retworkx.PyDiGraph() + graph.add_nodes_from(list(range(8))) + graph.remove_node(0) + graph.add_edges_from_no_data( + [(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]) + dist = retworkx.digraph_floyd_warshall_numpy(graph) + self.assertEqual(dist[0, 3], 3) + self.assertEqual(dist[0, 4], 4) + + def test_floyd_warshall_numpy_graph_cycle_no_weight_fn(self): + graph = retworkx.PyGraph() + graph.add_nodes_from(list(range(8))) + graph.remove_node(0) + graph.add_edges_from_no_data( + [(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]) + dist = retworkx.graph_floyd_warshall_numpy(graph) + self.assertEqual(dist[0, 3], 3) + self.assertEqual(dist[0, 4], 3) + + def test_floyd_warshall_numpy_digraph_cycle_default_weight(self): + graph = retworkx.PyDiGraph() + graph.add_nodes_from(list(range(8))) + graph.remove_node(0) + graph.add_edges_from_no_data( + [(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]) + dist = retworkx.digraph_floyd_warshall_numpy(graph, default_weight=2) + self.assertEqual(dist[0, 3], 6) + self.assertEqual(dist[0, 4], 8) + + def test_floyd_warshall_numpy_graph_cycle_default_weight(self): + graph = retworkx.PyGraph() + graph.add_nodes_from(list(range(8))) + graph.remove_node(0) + graph.add_edges_from_no_data( + [(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]) + dist = retworkx.graph_floyd_warshall_numpy(graph, default_weight=2) + self.assertEqual(dist[0, 3], 6) + self.assertEqual(dist[0, 4], 6)