Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,5 @@ Return Iterator Types
.. autosummary::
:toctree: stubs

retworkx.BFSSuccessors
retworkx.BFSSuccessors
retworkx.NodeIndices
72 changes: 42 additions & 30 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use petgraph::visit::{
};

use super::dot_utils::build_dot;
use super::iterators::NodeIndices;
use super::{
is_directed_acyclic_graph, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes,
NoSuitableNeighbors, NodesRemoved,
Expand Down Expand Up @@ -510,10 +511,12 @@ impl PyDiGraph {
/// Return a list of all node indexes.
///
/// :returns: A list of all the node indexes in the graph
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "()"]
pub fn node_indexes(&self) -> Vec<usize> {
self.graph.node_indices().map(|node| node.index()).collect()
pub fn node_indexes(&self) -> NodeIndices {
NodeIndices {
nodes: self.graph.node_indices().map(|node| node.index()).collect(),
}
}

/// Return True if there is an edge from node_a to node_b.
Expand Down Expand Up @@ -1327,13 +1330,16 @@ impl PyDiGraph {
/// :param int node: The index of the node to get the neighbors of
///
/// :returns: A list of the neighbor node indicies
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(node, /)"]
pub fn neighbors(&self, node: usize) -> Vec<usize> {
self.graph
.neighbors(NodeIndex::new(node))
.map(|node| node.index())
.collect()
pub fn neighbors(&self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors(NodeIndex::new(node))
.map(|node| node.index())
.collect(),
}
}

/// Get the successor indices of a node.
Expand All @@ -1344,16 +1350,19 @@ impl PyDiGraph {
/// :param int node: The index of the node to get the successors of
///
/// :returns: A list of the neighbor node indicies
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(node, /)"]
pub fn successor_indices(&mut self, node: usize) -> Vec<usize> {
self.graph
.neighbors_directed(
NodeIndex::new(node),
petgraph::Direction::Outgoing,
)
.map(|node| node.index())
.collect()
pub fn successor_indices(&mut self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors_directed(
NodeIndex::new(node),
petgraph::Direction::Outgoing,
)
.map(|node| node.index())
.collect(),
}
}

/// Get the predecessor indices of a node.
Expand All @@ -1364,16 +1373,19 @@ impl PyDiGraph {
/// :param int node: The index of the node to get the predecessors of
///
/// :returns: A list of the neighbor node indicies
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(node, /)"]
pub fn predecessor_indices(&mut self, node: usize) -> Vec<usize> {
self.graph
.neighbors_directed(
NodeIndex::new(node),
petgraph::Direction::Incoming,
)
.map(|node| node.index())
.collect()
pub fn predecessor_indices(&mut self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors_directed(
NodeIndex::new(node),
petgraph::Direction::Incoming,
)
.map(|node| node.index())
.collect(),
}
}
/// Get the index and edge data for all parents of a node.
///
Expand Down Expand Up @@ -1424,15 +1436,15 @@ impl PyDiGraph {
/// as new nodes
///
/// :returns: A list of int indices of the newly created nodes
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(obj_list, /)"]
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> Vec<usize> {
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> NodeIndices {
let mut out_list: Vec<usize> = Vec::new();
for obj in obj_list {
let node_index = self.graph.add_node(obj);
out_list.push(node_index.index());
}
out_list
NodeIndices { nodes: out_list }
}

/// Remove nodes from the graph.
Expand Down
31 changes: 19 additions & 12 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use pyo3::types::{PyDict, PyList, PyLong, PyString, PyTuple};
use pyo3::Python;

use super::dot_utils::build_dot;
use super::iterators::NodeIndices;
use super::{NoEdgeBetweenNodes, NodesRemoved};

use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::*;
use petgraph::stable_graph::StableUnGraph;
Expand Down Expand Up @@ -355,10 +357,12 @@ impl PyGraph {
/// Return a list of all node indexes.
///
/// :returns: A list of all the node indexes in the graph
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "()"]
pub fn node_indexes(&self) -> Vec<usize> {
self.graph.node_indices().map(|node| node.index()).collect()
pub fn node_indexes(&self) -> NodeIndices {
NodeIndices {
nodes: self.graph.node_indices().map(|node| node.index()).collect(),
}
}

/// Return True if there is an edge between node_a to node_b.
Expand Down Expand Up @@ -719,15 +723,15 @@ impl PyGraph {
/// :param list obj_list: A list of python object to attach to the graph.
///
/// :returns indices: A list of int indices of the newly created nodes
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(obj_list, /)"]
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> Vec<usize> {
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> NodeIndices {
let mut out_list: Vec<usize> = Vec::new();
for obj in obj_list {
let node_index = self.graph.add_node(obj);
out_list.push(node_index.index());
}
out_list
NodeIndices { nodes: out_list }
}

/// Remove nodes from the graph.
Expand Down Expand Up @@ -783,13 +787,16 @@ impl PyGraph {
/// :param int node: The index of the node to get the neibhors of
///
/// :returns: A list of the neighbor node indicies
/// :rtype: list
/// :rtype: NodeIndices
#[text_signature = "(node, /)"]
pub fn neighbors(&self, node: usize) -> Vec<usize> {
self.graph
.neighbors(NodeIndex::new(node))
.map(|node| node.index())
.collect()
pub fn neighbors(&self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors(NodeIndex::new(node))
.map(|node| node.index())
.collect(),
}
}

/// Get the degree for a node
Expand Down
80 changes: 78 additions & 2 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@

use std::convert::TryInto;

use pyo3::class::PySequenceProtocol;
use pyo3::exceptions::PyIndexError;
use pyo3::class::{PyObjectProtocol, PySequenceProtocol};
use pyo3::exceptions::{PyIndexError, PyNotImplementedError};
use pyo3::prelude::*;
use pyo3::types::PySequence;

/// A custom class for the return from :func:`retworkx.bfs_successors`
///
Expand Down Expand Up @@ -61,3 +62,78 @@ impl PySequenceProtocol for BFSSuccessors {
}
}
}

/// A custom class for the return if node indices
Comment thread
mtreinish marked this conversation as resolved.
Outdated
///
/// This class is a container class for the results of functions that
/// return a list of node indices. It implements the Python sequence
/// protocol. So you can treat the return as read-only a sequence/list
Comment thread
mtreinish marked this conversation as resolved.
Outdated
/// that is integer indexed. If you want to use it as an iterator you
/// can by wrapping it in an ``iter()`` that will yield the results in
/// order.
///
/// For example::
///
/// import retworkx
///
/// graph = retworkx.generators.directed_path_graph(5)
/// nodes = retworkx.node_indexes(0)
/// # Index based access
/// third_element = nodes[2]
/// # Use as iterator
/// nodes_iter = iter(node)
/// first_element = next(nodes_iter)
/// second_element = next(nodes_iter)
///
#[pyclass(module = "retworkx")]
pub struct NodeIndices {
pub nodes: Vec<usize>,
}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for NodeIndices {
fn __richcmp__(
&self,
other: &'p PySequence,
op: pyo3::basic::CompareOp,
) -> PyResult<bool> {
let compare = |other: &PySequence| -> PyResult<bool> {
if other.len()? as usize != self.nodes.len() {
return Ok(false);
}
for i in 0..self.nodes.len() {
let other_raw = other.get_item(i.try_into().unwrap())?;
let other_value: usize = other_raw.extract()?;
if other_value != self.nodes[i] {
return Ok(false);
}
}
Ok(true)
};
match op {
pyo3::basic::CompareOp::Eq => compare(other),
pyo3::basic::CompareOp::Ne => match compare(other) {
Ok(res) => Ok(!res),
Err(err) => Err(err),
},
_ => Err(PyNotImplementedError::new_err(
"Comparison not implemented",
)),
}
}
}

#[pyproto]
impl PySequenceProtocol for NodeIndices {
fn __len__(&self) -> PyResult<usize> {
Ok(self.nodes.len())
}

fn __getitem__(&'p self, idx: isize) -> PyResult<usize> {
if idx >= self.nodes.len().try_into().unwrap() {
Err(PyIndexError::new_err(format!("Invalid index, {}", idx)))
} else {
Ok(self.nodes[idx as usize])
}
}
}
Loading