Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ Algorithm Functions
retworkx.digraph_all_simple_paths
retworkx.graph_astar_shortest_path
retworkx.digraph_astar_shortest_path
retworkx.graph_dijkstra_shortest_paths
retworkx.digraph_dijkstra_shortest_paths
retworkx.graph_dijkstra_shortest_path_lengths
retworkx.digraph_dijkstra_shortest_path_lengths
retworkx.graph_k_shortest_path_lengths
Expand Down
23 changes: 21 additions & 2 deletions src/dijkstra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
// License for the specific language governing permissions and limitations
// under the License.

// This module is copied and forked from the upstream petgraph repository,
// specifically:
// This module was originally copied and forked from the upstream petgraph
// repository, specifically:
// https://github.com/petgraph/petgraph/blob/0.5.1/src/dijkstra.rs
// this was necessary to modify the error handling to allow python callables
// to be use for the input functions for edge_cost and return any exceptions
Expand Down Expand Up @@ -42,6 +42,11 @@ use crate::astar::MinScored;
/// If `goal` is not `None`, then the algorithm terminates once the `goal` node's
/// cost is calculated.
///
/// If `path` is not `None`, then the algorithm will mutate the input
/// hashbrown::HashMap to insert an entry where the index is the dest node index
/// the value is a Vec of node indices of the path starting with `start` and
/// ending at the index.
///
/// Returns a `HashMap` that maps `NodeId` to path cost.
/// # Example
/// ```rust
Expand Down Expand Up @@ -97,6 +102,7 @@ pub fn dijkstra<G, F, K>(
start: G::NodeId,
goal: Option<G::NodeId>,
mut edge_cost: F,
mut path: Option<&mut HashMap<G::NodeId, Vec<G::NodeId>>>,
Comment thread
mtreinish marked this conversation as resolved.
) -> PyResult<HashMap<G::NodeId, K>>
where
G: IntoEdges + Visitable,
Expand All @@ -110,6 +116,9 @@ where
let zero_score = K::default();
scores.insert(start, zero_score);
visit_next.push(MinScored(zero_score, start));
if path.is_some() {
path.as_mut().unwrap().insert(start, vec![start]);
}
while let Some(MinScored(node_score, node)) = visit_next.pop() {
if visited.is_visited(&node) {
continue;
Expand All @@ -134,6 +143,16 @@ where
Vacant(ent) => {
ent.insert(next_score);
visit_next.push(MinScored(next_score, next));
if path.is_some() {
let mut node_path =
path.as_mut().unwrap().get(&node).unwrap().clone();
path.as_mut().unwrap().entry(next).or_insert({
let mut new_vec: Vec<G::NodeId> = Vec::new();
new_vec.append(&mut node_path);
new_vec.push(next);
new_vec
});
}
}
}
}
Expand Down
202 changes: 177 additions & 25 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ fn graph_floyd_warshall_numpy(
// Build adjacency matrix
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
weight_callable(py, &weight_fn, &weight, default_weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
mat[[j, i]] = mat[[j, i]].min(edge_weight);
}
Expand Down Expand Up @@ -893,7 +893,7 @@ fn digraph_floyd_warshall_numpy(
// Build adjacency matrix
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
weight_callable(py, &weight_fn, &weight, default_weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
if as_undirected {
mat[[j, i]] = mat[[j, i]].min(edge_weight);
Expand Down Expand Up @@ -1167,21 +1167,6 @@ pub fn graph_distance_matrix(
Ok(matrix.into_pyarray(py).into())
}

fn weight_callable(
py: Python,
weight_fn: &Option<PyObject>,
weight: PyObject,
default: f64,
) -> PyResult<f64> {
match weight_fn {
Some(weight_fn) => {
let res = weight_fn.call1(py, (weight,))?;
res.extract(py)
}
None => Ok(default),
}
}

/// Return the adjacency matrix for a PyDiGraph object
///
/// In the case where there are multiple edges between nodes the value in the
Expand Down Expand Up @@ -1220,7 +1205,7 @@ fn digraph_adjacency_matrix(
let mut matrix = Array2::<f64>::zeros((n, n));
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
weight_callable(py, &weight_fn, &weight, default_weight)?;
matrix[[i, j]] += edge_weight;
}
Ok(matrix.into_pyarray(py).into())
Expand Down Expand Up @@ -1263,7 +1248,7 @@ fn graph_adjacency_matrix(
let mut matrix = Array2::<f64>::zeros((n, n));
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
weight_callable(py, &weight_fn, &weight, default_weight)?;
matrix[[i, j]] += edge_weight;
matrix[[j, i]] += edge_weight;
}
Expand Down Expand Up @@ -1384,6 +1369,163 @@ fn digraph_all_simple_paths(
Ok(result)
}

fn weight_callable(
py: Python,
weight_fn: &Option<PyObject>,
weight: &PyObject,
Comment thread
lcapelluto marked this conversation as resolved.
default: f64,
) -> PyResult<f64> {
match weight_fn {
Some(weight_fn) => {
let res = weight_fn.call1(py, (weight,))?;
res.extract(py)
}
None => Ok(default),
}
}

/// Find the shortest path from a node
///
/// This function will generate the shortest path from a source node using
/// Dijkstra's algorithm.
///
/// :param PyGraph graph:
/// :param int source: The node index to find paths from
/// :param int target: An optional target to find a path to
/// :param weight_fn: An optional weight function for an edge. It will accept
/// a single argument, the edge's weight object and will return a float which
/// will be used to represent the weight/cost of the edge
/// :param float default_weight: If ``weight_fn`` isn't specified this optional
/// float value will be used for the weight/cost of each edge.
/// :param bool as_undirected: If set to true the graph will be treated as
/// undirected for finding the shortest path.
///
/// :return: Dictionary of paths. The keys are destination node indices and
/// the dict values are lists of node indices making the path.
/// :rtype: dict
#[pyfunction(default_weight = "1.0", as_undirected = "false")]
#[text_signature = "(graph, source, /, target=None weight_fn=None, default_weight=1.0)"]
pub fn graph_dijkstra_shortest_paths(
py: Python,
graph: &graph::PyGraph,
source: usize,
target: Option<usize>,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PyObject> {
let start = NodeIndex::new(source);
let goal_index: Option<NodeIndex> = match target {
Some(node) => Some(NodeIndex::new(node)),
None => None,
};
let mut paths: HashMap<NodeIndex, Vec<NodeIndex>> = HashMap::new();
dijkstra::dijkstra(
graph,
start,
goal_index,
|e| weight_callable(py, &weight_fn, e.weight(), default_weight),
Some(&mut paths),
)?;

let out_dict = PyDict::new(py);
for (index, value) in paths {
let int_index = index.index();
if int_index == source {
continue;
}
if (target.is_some() && target.unwrap() == int_index)
|| target.is_none()
{
out_dict.set_item(
int_index,
value
.iter()
.map(|index| index.index())
.collect::<Vec<usize>>(),
)?;
}
}
Ok(out_dict.into())
}

/// Find the shortest path from a node
///
/// This function will generate the shortest path from a source node using
/// Dijkstra's algorithm.
///
/// :param PyDiGraph graph:
/// :param int source: The node index to find paths from
/// :param int target: An optional target path to find the path
/// :param weight_fn: An optional weight function for an edge. It will accept
/// a single argument, the edge's weight object and will return a float which
/// will be used to represent the weight/cost of the edge
/// :param float default_weight: If ``weight_fn`` isn't specified this optional
/// float value will be used for the weight/cost of each edge.
/// :param bool as_undirected: If set to true the graph will be treated as
/// undirected for finding the shortest path.
///
/// :return: Dictionary of paths. The keys are destination node indices and
/// the dict values are lists of node indices making the path.
/// :rtype: dict
#[pyfunction(default_weight = "1.0", as_undirected = "false")]
#[text_signature = "(graph, source, /, target=None weight_fn=None, default_weight=1.0, as_undirected=False)"]
pub fn digraph_dijkstra_shortest_paths(
py: Python,
graph: &digraph::PyDiGraph,
source: usize,
target: Option<usize>,
weight_fn: Option<PyObject>,
default_weight: f64,
as_undirected: bool,
) -> PyResult<PyObject> {
let start = NodeIndex::new(source);
let goal_index: Option<NodeIndex> = match target {
Some(node) => Some(NodeIndex::new(node)),
None => None,
};
let mut paths: HashMap<NodeIndex, Vec<NodeIndex>> = HashMap::new();
if as_undirected {
dijkstra::dijkstra(
// TODO: Use petgraph undirected adapter after
// https://github.com/petgraph/petgraph/pull/318 is available in
// a petgraph release.
&graph.to_undirected(py),
start,
goal_index,
|e| weight_callable(py, &weight_fn, e.weight(), default_weight),
Some(&mut paths),
)?;
} else {
dijkstra::dijkstra(
graph,
start,
goal_index,
|e| weight_callable(py, &weight_fn, e.weight(), default_weight),
Some(&mut paths),
)?;
}

let out_dict = PyDict::new(py);
for (index, value) in paths {
let int_index = index.index();
if int_index == source {
continue;
}
if (target.is_some() && target.unwrap() == int_index)
|| target.is_none()
{
out_dict.set_item(
int_index,
value
.iter()
.map(|index| index.index())
.collect::<Vec<usize>>(),
)?;
}
}
Ok(out_dict.into())
}

/// Compute the lengths of the shortest paths for a PyGraph object using
/// Dijkstra's algorithm
///
Expand Down Expand Up @@ -1423,9 +1565,13 @@ fn graph_dijkstra_shortest_path_lengths(
None => None,
};

let res = dijkstra::dijkstra(graph, start, goal_index, |e| {
edge_cost_callable(e.weight())
})?;
let res = dijkstra::dijkstra(
graph,
start,
goal_index,
|e| edge_cost_callable(e.weight()),
None,
)?;
let out_dict = PyDict::new(py);
for (index, value) in res {
let int_index = index.index();
Expand Down Expand Up @@ -1478,9 +1624,13 @@ fn digraph_dijkstra_shortest_path_lengths(
None => None,
};

let res = dijkstra::dijkstra(graph, start, goal_index, |e| {
edge_cost_callable(e.weight())
})?;
let res = dijkstra::dijkstra(
graph,
start,
goal_index,
|e| edge_cost_callable(e.weight()),
None,
)?;
let out_dict = PyDict::new(py);
for (index, value) in res {
let int_index = index.index();
Expand Down Expand Up @@ -2234,6 +2384,8 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(graph_adjacency_matrix))?;
m.add_wrapped(wrap_pyfunction!(graph_all_simple_paths))?;
m.add_wrapped(wrap_pyfunction!(digraph_all_simple_paths))?;
m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_paths))?;
m.add_wrapped(wrap_pyfunction!(digraph_dijkstra_shortest_paths))?;
m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_path_lengths))?;
m.add_wrapped(wrap_pyfunction!(digraph_dijkstra_shortest_path_lengths))?;
m.add_wrapped(wrap_pyfunction!(graph_astar_shortest_path))?;
Expand Down
Loading