Skip to content

Commit 355efc1

Browse files
authored
Better handle cudf.pandas in from_pandas_edgelist (#4525)
Optimistically use cupy, but fall back to numpy if necessary. Also, bump lint versions. CC @rlratzel Authors: - Erik Welch (https://github.com/eriknw) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: #4525
1 parent 127d3be commit 355efc1

File tree

3 files changed

+26
-16
lines changed

3 files changed

+26
-16
lines changed

python/nx-cugraph/lint.yaml

+7-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ repos:
2626
- id: mixed-line-ending
2727
- id: trailing-whitespace
2828
- repo: https://github.com/abravalheri/validate-pyproject
29-
rev: v0.17
29+
rev: v0.18
3030
hooks:
3131
- id: validate-pyproject
3232
name: Validate pyproject.toml
@@ -40,7 +40,7 @@ repos:
4040
hooks:
4141
- id: isort
4242
- repo: https://github.com/asottile/pyupgrade
43-
rev: v3.15.2
43+
rev: v3.16.0
4444
hooks:
4545
- id: pyupgrade
4646
args: [--py39-plus]
@@ -50,18 +50,18 @@ repos:
5050
- id: black
5151
# - id: black-jupyter
5252
- repo: https://github.com/astral-sh/ruff-pre-commit
53-
rev: v0.4.4
53+
rev: v0.5.1
5454
hooks:
5555
- id: ruff
5656
args: [--fix-only, --show-fixes] # --unsafe-fixes]
5757
- repo: https://github.com/PyCQA/flake8
58-
rev: 7.0.0
58+
rev: 7.1.0
5959
hooks:
6060
- id: flake8
6161
args: ['--per-file-ignores=_nx_cugraph/__init__.py:E501', '--extend-ignore=SIM105'] # Why is this necessary?
6262
additional_dependencies: &flake8_dependencies
6363
# These versions need to be updated manually
64-
- flake8==7.0.0
64+
- flake8==7.1.0
6565
- flake8-bugbear==24.4.26
6666
- flake8-simplify==0.21.0
6767
- repo: https://github.com/asottile/yesqa
@@ -70,14 +70,14 @@ repos:
7070
- id: yesqa
7171
additional_dependencies: *flake8_dependencies
7272
- repo: https://github.com/codespell-project/codespell
73-
rev: v2.2.6
73+
rev: v2.3.0
7474
hooks:
7575
- id: codespell
7676
types_or: [python, rst, markdown]
7777
additional_dependencies: [tomli]
7878
files: ^(nx_cugraph|docs)/
7979
- repo: https://github.com/astral-sh/ruff-pre-commit
80-
rev: v0.4.4
80+
rev: v0.5.1
8181
hooks:
8282
- id: ruff
8383
- repo: https://github.com/pre-commit/pre-commit-hooks

python/nx-cugraph/nx_cugraph/classes/graph.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,13 @@ def from_coo(
124124
# Easy and fast sanity checks
125125
if size != new_graph.dst_indices.size:
126126
raise ValueError
127-
for attr in ["edge_values", "edge_masks"]:
128-
if datadict := getattr(new_graph, attr):
127+
for edge_attr in ["edge_values", "edge_masks"]:
128+
if datadict := getattr(new_graph, edge_attr):
129129
for key, val in datadict.items():
130130
if val.shape[0] != size:
131131
raise ValueError(key)
132-
for attr in ["node_values", "node_masks"]:
133-
if datadict := getattr(new_graph, attr):
132+
for node_attr in ["node_values", "node_masks"]:
133+
if datadict := getattr(new_graph, node_attr):
134134
for key, val in datadict.items():
135135
if val.shape[0] != N:
136136
raise ValueError(key)

python/nx-cugraph/nx_cugraph/convert_matrix.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,35 @@ def from_pandas_edgelist(
3535
):
3636
"""cudf.DataFrame inputs also supported; value columns with str are unsupported."""
3737
graph_class, inplace = _create_using_class(create_using)
38+
# Try to be optimal whether using pandas, cudf, or cudf.pandas
3839
src_array = df[source].to_numpy()
3940
dst_array = df[target].to_numpy()
41+
try:
42+
# Optimistically try to use cupy, but fall back to numpy if necessary
43+
src_array = cp.asarray(src_array)
44+
dst_array = cp.asarray(dst_array)
45+
np_or_cp = cp
46+
except ValueError:
47+
src_array = np.asarray(src_array)
48+
dst_array = np.asarray(dst_array)
49+
np_or_cp = np
4050
# TODO: create renumbering helper function(s)
4151
# Renumber step 0: node keys
42-
nodes = np.unique(np.concatenate([src_array, dst_array]))
52+
nodes = np_or_cp.unique(np_or_cp.concatenate([src_array, dst_array]))
4353
N = nodes.size
4454
kwargs = {}
4555
if N > 0 and (
4656
nodes[0] != 0
4757
or nodes[N - 1] != N - 1
4858
or (
4959
nodes.dtype.kind not in {"i", "u"}
50-
and not (nodes == np.arange(N, dtype=np.int64)).all()
60+
and not (nodes == np_or_cp.arange(N, dtype=np.int64)).all()
5161
)
5262
):
53-
# We need to renumber indices--np.searchsorted to the rescue!
63+
# We need to renumber indices--np_or_cp.searchsorted to the rescue!
5464
kwargs["id_to_key"] = nodes.tolist()
55-
src_indices = cp.array(np.searchsorted(nodes, src_array), index_dtype)
56-
dst_indices = cp.array(np.searchsorted(nodes, dst_array), index_dtype)
65+
src_indices = cp.asarray(np_or_cp.searchsorted(nodes, src_array), index_dtype)
66+
dst_indices = cp.asarray(np_or_cp.searchsorted(nodes, dst_array), index_dtype)
5767
else:
5868
src_indices = cp.array(src_array)
5969
dst_indices = cp.array(dst_array)

0 commit comments

Comments
 (0)