Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,5 @@ Return Iterator Types
.. autosummary::
:toctree: stubs

retworkx.BFSSuccessors
retworkx.BFSSuccessors
retworkx.NodeIndices
72 changes: 42 additions & 30 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use petgraph::visit::{
};

use super::dot_utils::build_dot;
use super::iterators::NodeIndices;
use super::{
is_directed_acyclic_graph, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes,
NoSuitableNeighbors, NodesRemoved,
Expand Down Expand Up @@ -510,10 +511,12 @@ impl PyDiGraph {
/// Return a list of all node indexes.
///
/// :returns: A list of all the node indexes in the graph
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "()"]
pub fn node_indexes(&self) -> Vec<usize> {
self.graph.node_indices().map(|node| node.index()).collect()
pub fn node_indexes(&self) -> NodeIndices {
NodeIndices {
nodes: self.graph.node_indices().map(|node| node.index()).collect(),
}
}

/// Return True if there is an edge from node_a to node_b.
Expand Down Expand Up @@ -1327,13 +1330,16 @@ impl PyDiGraph {
/// :param int node: The index of the node to get the neighbors of
///
/// :returns: A list of the neighbor node indicies
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(node, /)"]
pub fn neighbors(&self, node: usize) -> Vec<usize> {
self.graph
.neighbors(NodeIndex::new(node))
.map(|node| node.index())
.collect()
pub fn neighbors(&self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors(NodeIndex::new(node))
.map(|node| node.index())
.collect(),
}
}

/// Get the successor indices of a node.
Expand All @@ -1344,16 +1350,19 @@ impl PyDiGraph {
/// :param int node: The index of the node to get the successors of
///
/// :returns: A list of the neighbor node indicies
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(node, /)"]
pub fn successor_indices(&mut self, node: usize) -> Vec<usize> {
self.graph
.neighbors_directed(
NodeIndex::new(node),
petgraph::Direction::Outgoing,
)
.map(|node| node.index())
.collect()
pub fn successor_indices(&mut self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors_directed(
NodeIndex::new(node),
petgraph::Direction::Outgoing,
)
.map(|node| node.index())
.collect(),
}
}

/// Get the predecessor indices of a node.
Expand All @@ -1364,16 +1373,19 @@ impl PyDiGraph {
/// :param int node: The index of the node to get the predecessors of
///
/// :returns: A list of the neighbor node indicies
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(node, /)"]
pub fn predecessor_indices(&mut self, node: usize) -> Vec<usize> {
self.graph
.neighbors_directed(
NodeIndex::new(node),
petgraph::Direction::Incoming,
)
.map(|node| node.index())
.collect()
pub fn predecessor_indices(&mut self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors_directed(
NodeIndex::new(node),
petgraph::Direction::Incoming,
)
.map(|node| node.index())
.collect(),
}
}
/// Get the index and edge data for all parents of a node.
///
Expand Down Expand Up @@ -1424,15 +1436,15 @@ impl PyDiGraph {
/// as new nodes
///
/// :returns: A list of int indices of the newly created nodes
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(obj_list, /)"]
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> Vec<usize> {
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> NodeIndices {
let mut out_list: Vec<usize> = Vec::new();
for obj in obj_list {
let node_index = self.graph.add_node(obj);
out_list.push(node_index.index());
}
out_list
NodeIndices { nodes: out_list }
}

/// Remove nodes from the graph.
Expand Down
31 changes: 19 additions & 12 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use pyo3::types::{PyDict, PyList, PyLong, PyString, PyTuple};
use pyo3::Python;

use super::dot_utils::build_dot;
use super::iterators::NodeIndices;
use super::{NoEdgeBetweenNodes, NodesRemoved};

use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::*;
use petgraph::stable_graph::StableUnGraph;
Expand Down Expand Up @@ -355,10 +357,12 @@ impl PyGraph {
/// Return a list of all node indexes.
///
/// :returns: A list of all the node indexes in the graph
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "()"]
pub fn node_indexes(&self) -> Vec<usize> {
self.graph.node_indices().map(|node| node.index()).collect()
pub fn node_indexes(&self) -> NodeIndices {
NodeIndices {
nodes: self.graph.node_indices().map(|node| node.index()).collect(),
}
}

/// Return True if there is an edge between node_a to node_b.
Expand Down Expand Up @@ -719,15 +723,15 @@ impl PyGraph {
/// :param list obj_list: A list of python object to attach to the graph.
///
/// :returns indices: A list of int indices of the newly created nodes
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(obj_list, /)"]
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> Vec<usize> {
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> NodeIndices {
let mut out_list: Vec<usize> = Vec::new();
for obj in obj_list {
let node_index = self.graph.add_node(obj);
out_list.push(node_index.index());
}
out_list
NodeIndices { nodes: out_list }
}

/// Remove nodes from the graph.
Expand Down Expand Up @@ -783,13 +787,16 @@ impl PyGraph {
/// :param int node: The index of the node to get the neibhors of
///
/// :returns: A list of the neighbor node indicies
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(node, /)"]
pub fn neighbors(&self, node: usize) -> Vec<usize> {
self.graph
.neighbors(NodeIndex::new(node))
.map(|node| node.index())
.collect()
pub fn neighbors(&self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors(NodeIndex::new(node))
.map(|node| node.index())
.collect(),
}
}

/// Get the degree for a node
Expand Down
75 changes: 75 additions & 0 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,78 @@ impl PySequenceProtocol for BFSSuccessors {
}
}
}

/// A custom class for the return of node indices
///
/// This class is a container class for the results of functions that
/// return a list of node indices. It implements the Python sequence
/// protocol. So you can treat the return as a read-only sequence/list
/// that is integer indexed. If you want to use it as an iterator you
/// can by wrapping it in an ``iter()`` that will yield the results in
/// order.
///
/// For example::
///
/// import retworkx
///
/// graph = retworkx.generators.directed_path_graph(5)
/// nodes = retworkx.node_indexes(0)
/// # Index based access
/// third_element = nodes[2]
/// # Use as iterator
/// nodes_iter = iter(node)
/// first_element = next(nodes_iter)
/// second_element = next(nodes_iter)
///
#[pyclass(module = "retworkx")]
pub struct NodeIndices {
pub nodes: Vec<usize>,
}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for NodeIndices {
fn __richcmp__(
&self,
other: &'p PySequence,
op: pyo3::basic::CompareOp,
) -> PyResult<bool> {
let compare = |other: &PySequence| -> PyResult<bool> {
if other.len()? as usize != self.nodes.len() {
return Ok(false);
}
for i in 0..self.nodes.len() {
let other_raw = other.get_item(i.try_into().unwrap())?;
let other_value: usize = other_raw.extract()?;
if other_value != self.nodes[i] {
return Ok(false);
}
}
Ok(true)
};
match op {
pyo3::basic::CompareOp::Eq => compare(other),
pyo3::basic::CompareOp::Ne => match compare(other) {
Ok(res) => Ok(!res),
Err(err) => Err(err),
},
_ => Err(PyNotImplementedError::new_err(
"Comparison not implemented",
)),
}
}
}

#[pyproto]
impl PySequenceProtocol for NodeIndices {
fn __len__(&self) -> PyResult<usize> {
Ok(self.nodes.len())
}

fn __getitem__(&'p self, idx: isize) -> PyResult<usize> {
if idx >= self.nodes.len().try_into().unwrap() {
Err(PyIndexError::new_err(format!("Invalid index, {}", idx)))
} else {
Ok(self.nodes[idx as usize])
}
}
}
34 changes: 22 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use rand_pcg::Pcg64;
use rayon::prelude::*;

use crate::generators::PyInit_generators;
use crate::iterators::NodeIndices;

trait NodesRemoved {
fn nodes_removed(&self) -> bool;
Expand Down Expand Up @@ -113,14 +114,16 @@ fn longest_path(graph: &digraph::PyDiGraph) -> PyResult<Vec<usize>> {
/// object must be a DAG without a cycle.
///
/// :returns: The node indices of the longest path on the DAG
/// :rtype: list
/// :rtype: NodeIndices
///
/// :raises Exception: If an unexpected error occurs or a path can't be found
/// :raises DAGHasCycle: If the input PyDiGraph has a cycle
#[pyfunction]
#[text_signature = "(graph, /)"]
fn dag_longest_path(graph: &digraph::PyDiGraph) -> PyResult<Vec<usize>> {
longest_path(graph)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like longest_path still returns PyResult<Vec<usize>>. Should that not be updated to NodeIndices?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's easier to leave it returning Vec<usize> because dag_longest_path_length calls it too. I actually did that in the first rev before pushing this and it errored because there is no len() on NodeIndices it felt easier to leave it like this and have dag_longest_path just move the result into the NodeIndices struct as the return and leave it as a Vec<usize> for the other method

fn dag_longest_path(graph: &digraph::PyDiGraph) -> PyResult<NodeIndices> {
Ok(NodeIndices {
nodes: longest_path(graph)?,
})
}

/// Find the length of the longest path in a DAG
Expand Down Expand Up @@ -336,19 +339,21 @@ fn is_isomorphic_node_match(
/// :param PyDiGraph graph: The DAG to get the topological sort on
///
/// :returns: A list of node indices topologically sorted.
/// :rtype: list
/// :rtype: NodeIndices
///
/// :raises DAGHasCycle: if a cycle is encountered while sorting the graph
#[pyfunction]
#[text_signature = "(graph, /)"]
fn topological_sort(graph: &digraph::PyDiGraph) -> PyResult<Vec<usize>> {
fn topological_sort(graph: &digraph::PyDiGraph) -> PyResult<NodeIndices> {
let nodes = match algo::toposort(graph, None) {
Ok(nodes) => nodes,
Err(_err) => {
return Err(DAGHasCycle::new_err("Sort encountered a cycle"))
}
};
Ok(nodes.iter().map(|node| node.index()).collect())
Ok(NodeIndices {
nodes: nodes.iter().map(|node| node.index()).collect(),
})
}

fn dfs_edges<G>(graph: G, source: Option<usize>) -> Vec<(usize, usize)>
Expand Down Expand Up @@ -1808,7 +1813,7 @@ fn digraph_dijkstra_shortest_path_lengths(
///
/// :returns: The computed shortest path between node and finish as a list
/// of node indices.
/// :rtype: list
/// :rtype: NodeIndices
#[pyfunction]
#[text_signature = "(graph, node, goal_fn, edge_cost, estimate_cost, /)"]
fn graph_astar_shortest_path(
Expand All @@ -1818,7 +1823,7 @@ fn graph_astar_shortest_path(
goal_fn: PyObject,
edge_cost_fn: PyObject,
estimate_cost_fn: PyObject,
) -> PyResult<Vec<usize>> {
) -> PyResult<NodeIndices> {
let goal_fn_callable = |a: &PyObject| -> PyResult<bool> {
let res = goal_fn.call1(py, (a,))?;
let raw = res.to_object(py);
Expand Down Expand Up @@ -1858,7 +1863,9 @@ fn graph_astar_shortest_path(
))
}
};
Ok(path.1.into_iter().map(|x| x.index()).collect())
Ok(NodeIndices {
nodes: path.1.into_iter().map(|x| x.index()).collect(),
})
}

/// Compute the A* shortest path for a PyDiGraph
Expand All @@ -1880,7 +1887,7 @@ fn graph_astar_shortest_path(
///
/// :return: The computed shortest path between node and finish as a list
/// of node indices.
/// :rtype: list
/// :rtype: NodeIndices
#[pyfunction]
#[text_signature = "(graph, node, goal_fn, edge_cost, estimate_cost, /)"]
fn digraph_astar_shortest_path(
Expand All @@ -1890,7 +1897,7 @@ fn digraph_astar_shortest_path(
goal_fn: PyObject,
edge_cost_fn: PyObject,
estimate_cost_fn: PyObject,
) -> PyResult<Vec<usize>> {
) -> PyResult<NodeIndices> {
let goal_fn_callable = |a: &PyObject| -> PyResult<bool> {
let res = goal_fn.call1(py, (a,))?;
let raw = res.to_object(py);
Expand Down Expand Up @@ -1930,7 +1937,9 @@ fn digraph_astar_shortest_path(
))
}
};
Ok(path.1.into_iter().map(|x| x.index()).collect())
Ok(NodeIndices {
nodes: path.1.into_iter().map(|x| x.index()).collect(),
})
}

/// Return a :math:`G_{np}` directed random graph, also known as an
Expand Down Expand Up @@ -2551,6 +2560,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<digraph::PyDiGraph>()?;
m.add_class::<graph::PyGraph>()?;
m.add_class::<iterators::BFSSuccessors>()?;
m.add_class::<iterators::NodeIndices>()?;
m.add_wrapped(wrap_pymodule!(generators))?;
Ok(())
}
Loading