diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b040d4cc..f284c3df 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,14 +19,18 @@ repos: rev: v0.14.3 hooks: - id: ruff-check - args: [--fix] + args: [--fix, --config, "pyproject.toml"] - id: ruff-format + args: [--config, "pyproject.toml"] - repo: https://github.com/asottile/yesqa rev: v1.3.0 hooks: - id: yesqa additional_dependencies: - flake8==7.1.1 + exclude: | + (?x) + python/pylibwholegraph/pylibwholegraph/_doctor_check[.]py$ - repo: https://github.com/pre-commit/mirrors-clang-format rev: v20.1.4 hooks: diff --git a/ci/download-torch-wheels.sh b/ci/download-torch-wheels.sh new file mode 100755 index 00000000..82b22787 --- /dev/null +++ b/ci/download-torch-wheels.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# [description] +# +# Downloads a CUDA variant of 'torch' from the correct index, based on CUDA major version. +# +# This exists to avoid using 'pip --extra-index-url', which has these undesirable properties: +# +# - allows for CPU-only 'torch' to be downloaded from pypi.org +# - allows for other non-torch packages like 'numpy' to be downloaded from the PyTorch indices +# - increases solve complexity for 'pip' +# + +set -e -u -o pipefail + +TORCH_WHEEL_DIR="${1}" + +# skip download attempt on CUDA versions where we know there isn't a 'torch' CUDA wheel. +CUDA_MAJOR="${RAPIDS_CUDA_VERSION%%.*}" +CUDA_MINOR=$(echo "${RAPIDS_CUDA_VERSION}" | cut -d'.' -f2) +if \ + { [ "${CUDA_MAJOR}" -eq 12 ] && [ "${CUDA_MINOR}" -lt 9 ]; } \ + || { [ "${CUDA_MAJOR}" -eq 13 ] && [ "${CUDA_MINOR}" -gt 0 ]; } \ + || [ "${CUDA_MAJOR}" -gt 13 ]; +then + rapids-logger "Skipping 'torch' wheel download. (requires CUDA 12.9+ or 13.0, found ${RAPIDS_CUDA_VERSION})" + exit 0 +fi + +# Ensure CUDA-enabled 'torch' packages are always used. 
+# +# Downloading + passing the downloaded file as a requirement forces the use of this +# package and ensures 'pip' considers all of its requirements. +# +# Not appending this to PIP_CONSTRAINT, because we don't want the torch '--extra-index-url' +# to leak outside of this script into other 'pip {download,install}'' calls. +rapids-dependency-file-generator \ + --output requirements \ + --file-key "torch_only" \ + --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION};dependencies=${RAPIDS_DEPENDENCIES};require_gpu=true" \ +| tee ./torch-constraints.txt + +rapids-pip-retry download \ + --isolated \ + --prefer-binary \ + --no-deps \ + -d "${TORCH_WHEEL_DIR}" \ + --constraint "${PIP_CONSTRAINT}" \ + --constraint ./torch-constraints.txt \ + 'torch' diff --git a/ci/run_cugraph_pyg_pytests.sh b/ci/run_cugraph_pyg_pytests.sh index 4431a013..da255e71 100755 --- a/ci/run_cugraph_pyg_pytests.sh +++ b/ci/run_cugraph_pyg_pytests.sh @@ -1,5 +1,5 @@ #!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 set -euo pipefail @@ -7,7 +7,7 @@ set -euo pipefail # Support invoking run_cugraph_pyg_pytests.sh outside the script directory cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/cugraph-pyg/cugraph_pyg -pytest --cache-clear --benchmark-disable "$@" . +pytest -rs --cache-clear --benchmark-disable "$@" . # Used to skip certain examples in CI due to memory limitations export CI=true diff --git a/ci/run_pylibwholegraph_pytests.sh b/ci/run_pylibwholegraph_pytests.sh index d9c858e1..805698d0 100755 --- a/ci/run_pylibwholegraph_pytests.sh +++ b/ci/run_pylibwholegraph_pytests.sh @@ -1,5 +1,5 @@ #!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 set -euo pipefail @@ -7,4 +7,4 @@ set -euo pipefail # Support invoking run_pytests.sh outside the script directory cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/pylibwholegraph/pylibwholegraph/ -pytest --cache-clear --forked --import-mode=append "$@" tests +pytest -rs --cache-clear --forked --import-mode=append "$@" tests diff --git a/ci/test_wheel_cugraph-pyg.sh b/ci/test_wheel_cugraph-pyg.sh index 80148048..54c425d1 100755 --- a/ci/test_wheel_cugraph-pyg.sh +++ b/ci/test_wheel_cugraph-pyg.sh @@ -15,12 +15,30 @@ LIBWHOLEGRAPH_WHEELHOUSE=$(RAPIDS_PY_WHEEL_NAME="libwholegraph_${RAPIDS_PY_CUDA_ PYLIBWHOLEGRAPH_WHEELHOUSE=$(rapids-download-from-github "$(rapids-package-name "wheel_python" pylibwholegraph --stable --cuda "$RAPIDS_CUDA_VERSION")") CUGRAPH_PYG_WHEELHOUSE=$(RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" RAPIDS_PY_WHEEL_PURE="1" rapids-download-wheels-from-github python) -CUDA_MAJOR="${RAPIDS_CUDA_VERSION%%.*}" +# generate constraints (possibly pinning to oldest support versions of dependencies) +rapids-generate-pip-constraints test_cugraph_pyg "${PIP_CONSTRAINT}" -if [[ "${CUDA_MAJOR}" == "12" ]]; then - PYTORCH_INDEX="https://download.pytorch.org/whl/cu126" +PIP_INSTALL_ARGS=( + --prefer-binary + --constraint "${PIP_CONSTRAINT}" + --extra-index-url 'https://pypi.nvidia.com' + "${LIBWHOLEGRAPH_WHEELHOUSE}"/*.whl + "$(echo "${PYLIBWHOLEGRAPH_WHEELHOUSE}"/pylibwholegraph_"${RAPIDS_PY_CUDA_SUFFIX}"*.whl)" + "$(echo "${CUGRAPH_PYG_WHEELHOUSE}"/cugraph_pyg_"${RAPIDS_PY_CUDA_SUFFIX}"*.whl)[test]" +) + +# ensure a CUDA variant of 'torch' is used (if one is available) +TORCH_WHEEL_DIR="$(mktemp -d)" +./ci/download-torch-wheels.sh "${TORCH_WHEEL_DIR}" + +# 'cugraph-pyg' is still expected to be importable +# and testable in an environment where 'torch' isn't installed. 
+torch_downloaded=true +if [ -z "$(ls -A ${TORCH_WHEEL_DIR} 2>/dev/null)" ]; then + rapids-echo-stderr "No 'torch' wheels downloaded." + torch_downloaded=false else - PYTORCH_INDEX="https://download.pytorch.org/whl/cu130" + PIP_INSTALL_ARGS+=("${TORCH_WHEEL_DIR}"/torch-*.whl) fi # notes: @@ -30,12 +48,7 @@ fi # its dependencies are available from pypi.org # rapids-pip-retry install \ - -v \ - --extra-index-url "${PYTORCH_INDEX}" \ - --extra-index-url 'https://pypi.nvidia.com' \ - "${LIBWHOLEGRAPH_WHEELHOUSE}"/*.whl \ - "$(echo "${PYLIBWHOLEGRAPH_WHEELHOUSE}"/pylibwholegraph_"${RAPIDS_PY_CUDA_SUFFIX}"*.whl)" \ - "$(echo "${CUGRAPH_PYG_WHEELHOUSE}"/cugraph_pyg_"${RAPIDS_PY_CUDA_SUFFIX}"*.whl)[test]" + "${PIP_INSTALL_ARGS[@]}" # RAPIDS_DATASET_ROOT_DIR is used by test scripts export RAPIDS_DATASET_ROOT_DIR="$(realpath datasets)" @@ -47,5 +60,34 @@ popd # Enable legacy behavior of torch.load for examples relying on ogb export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 -rapids-logger "pytest cugraph-pyg (single GPU)" +if [[ "${torch_downloaded}" == "true" ]]; then + # TODO: remove this when RAPIDS wheels and 'torch' CUDA wheels have compatible package requirements + # + # * https://github.com/rapidsai/cugraph/issues/5443 + # * https://github.com/rapidsai/build-planning/issues/257 + # * https://github.com/rapidsai/build-planning/issues/255 + # + CUDA_MAJOR="${RAPIDS_CUDA_VERSION%%.*}" + CUDA_MINOR=$(echo "${RAPIDS_CUDA_VERSION}" | cut -d'.' -f2) + if [[ "${CUDA_MAJOR}" == "13" ]]; then + pip install \ + --upgrade \ + "nvidia-nvjitlink>=${CUDA_MAJOR}.${CUDA_MINOR}" + fi + + # 'torch' is an optional dependency of 'cugraph-pyg'... confirm that it's actually + # installed here and that we've installed a package with CUDA support. 
+ rapids-logger "Confirming that PyTorch is installed" + python -c "import torch; assert torch.cuda.is_available()" + + rapids-logger "pytest cugraph-pyg (single GPU, with 'torch')" + ./ci/run_cugraph_pyg_pytests.sh +fi + +rapids-logger "import cugraph-pyg (no 'torch')" +./ci/uninstall-torch-wheels.sh + +python -c "import cugraph_pyg; print(f'cugraph-pyg version: {cugraph_pyg.__version__}')" + +rapids-logger "pytest cugraph-pyg (no 'torch')" ./ci/run_cugraph_pyg_pytests.sh diff --git a/ci/test_wheel_pylibwholegraph.sh b/ci/test_wheel_pylibwholegraph.sh index c5513cc9..ec31f656 100755 --- a/ci/test_wheel_pylibwholegraph.sh +++ b/ci/test_wheel_pylibwholegraph.sh @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -set -e # abort the script on error -set -o pipefail # piped commands propagate their error -set -E # ERR traps are inherited by subcommands +set -euo pipefail # Delete system libnccl.so to ensure the wheel is used. 
# (but only do this in CI, to avoid breaking local dev environments) @@ -18,23 +16,68 @@ RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" LIBWHOLEGRAPH_WHEELHOUSE=$(RAPIDS_PY_WHEEL_NAME="libwholegraph_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-github cpp) PYLIBWHOLEGRAPH_WHEELHOUSE=$(rapids-download-from-github "$(rapids-package-name "wheel_python" pylibwholegraph --stable --cuda "$RAPIDS_CUDA_VERSION")") -# determine pytorch source -if [[ "${CUDA_MAJOR}" == "12" ]]; then - PYTORCH_INDEX="https://download.pytorch.org/whl/cu126" -else - PYTORCH_INDEX="https://download.pytorch.org/whl/cu130" -fi RAPIDS_TESTS_DIR=${RAPIDS_TESTS_DIR:-"${PWD}/test-results"} RAPIDS_COVERAGE_DIR=${RAPIDS_COVERAGE_DIR:-"${PWD}/coverage-results"} mkdir -p "${RAPIDS_TESTS_DIR}" "${RAPIDS_COVERAGE_DIR}" +# generate constraints (possibly pinning to oldest support versions of dependencies) +rapids-generate-pip-constraints test_pylibwholegraph "${PIP_CONSTRAINT}" + +PIP_INSTALL_ARGS=( + --prefer-binary + --constraint "${PIP_CONSTRAINT}" + "$(echo "${PYLIBWHOLEGRAPH_WHEELHOUSE}"/pylibwholegraph*.whl)[test]" + "${LIBWHOLEGRAPH_WHEELHOUSE}"/*.whl +) + +# ensure a CUDA variant of 'torch' is used (if one is available) +TORCH_WHEEL_DIR="$(mktemp -d)" +./ci/download-torch-wheels.sh "${TORCH_WHEEL_DIR}" + +# 'pylibwholegraph' is still expected to be importable +# and testable in an environment where 'torch' isn't installed. +torch_downloaded=true +if [ -z "$(ls -A ${TORCH_WHEEL_DIR} 2>/dev/null)" ]; then + rapids-echo-stderr "No 'torch' wheels downloaded." 
+ torch_downloaded=false +else + PIP_INSTALL_ARGS+=("${TORCH_WHEEL_DIR}"/torch-*.whl) +fi + # echo to expand wildcard before adding `[extra]` requires for pip rapids-logger "Installing Packages" rapids-pip-retry install \ - --extra-index-url ${PYTORCH_INDEX} \ - "$(echo "${PYLIBWHOLEGRAPH_WHEELHOUSE}"/pylibwholegraph*.whl)[test]" \ - "${LIBWHOLEGRAPH_WHEELHOUSE}"/*.whl \ - 'torch>=2.3' + "${PIP_INSTALL_ARGS[@]}" + + +if [[ "${torch_downloaded}" == "true" ]]; then + # TODO: remove this when RAPIDS wheels and 'torch' CUDA wheels have compatible package requirements + # + # * https://github.com/rapidsai/cugraph/issues/5443 + # * https://github.com/rapidsai/build-planning/issues/257 + # * https://github.com/rapidsai/build-planning/issues/255 + # + CUDA_MAJOR="${RAPIDS_CUDA_VERSION%%.*}" + CUDA_MINOR=$(echo "${RAPIDS_CUDA_VERSION}" | cut -d'.' -f2) + if [[ "${CUDA_MAJOR}" == "13" ]]; then + pip install \ + --upgrade \ + "nvidia-nvjitlink>=${CUDA_MAJOR}.${CUDA_MINOR}" + fi + + # 'torch' is an optional dependency of 'pylibwholegraph'... confirm that it's actually + # installed here and that we've installed a package with CUDA support. + rapids-logger "Confirming that PyTorch is installed" + python -c "import torch; assert torch.cuda.is_available()" + + rapids-logger "pytest pylibwholegraph (with 'torch')" + ./ci/run_pylibwholegraph_pytests.sh +fi + +rapids-logger "import pylibwholegraph (no 'torch')" +./ci/uninstall-torch-wheels.sh + +python -c "import pylibwholegraph; print(f'pylibwholegraph version: {pylibwholegraph.__version__}')" -rapids-logger "pytest pylibwholegraph" +rapids-logger "pytest pylibwholegraph (no 'torch')" ./ci/run_pylibwholegraph_pytests.sh diff --git a/ci/uninstall-torch-wheels.sh b/ci/uninstall-torch-wheels.sh new file mode 100755 index 00000000..3590bdc0 --- /dev/null +++ b/ci/uninstall-torch-wheels.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +pip uninstall --yes 'torch' + +# 'pytest' leaves behind some pycache files in site-packages/torch that make 'import torch' +# seem to "work" even though there's not really a package there, leading to errors like +# "module 'torch' has no attribute 'distributed'" +# +# For the sake of testing, just fully delete 'torch' from site-packages to simulate an environment +# where it was never installed. +SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") +rm -rf "${SITE_PACKAGES}/torch" diff --git a/ci/validate_wheel.sh b/ci/validate_wheel.sh index 42d0a8bf..88ba85aa 100755 --- a/ci/validate_wheel.sh +++ b/ci/validate_wheel.sh @@ -43,3 +43,22 @@ rapids-logger "validate packages with 'twine'" twine check \ --strict \ "$(echo ${wheel_dir_relative_path}/*.whl)" + +rapids-logger "validating that the wheel doesn't depend on 'torch' (even in an extra)" +WHEEL_FILE="$(echo ${wheel_dir_relative_path}/*.whl)" + +# NOTE: group of specifiers after 'torch' to avoid a false positive like 'torch-geometric' +# Use '|| true' so grep not finding any matches (exit 1) does not kill the script under set -e +unzip -p "${WHEEL_FILE}" '*.dist-info/METADATA' \ +| grep -E '^Requires-Dist:.*torch[><=!~ ]+.*' \ +| tee matches.txt || true + +if [[ -s ./matches.txt ]]; then + echo -n "Wheel '${WHEEL_FILE}' appears to depend on 'torch'. Remove that dependency. " + echo -n "We prefer to not declare a 'torch' dependency and allow it to be managed separately, " + echo "to ensure tight control over the variants installed (including for DLFW builds)." 
+ exit 1 +else + echo "No dependency on 'torch' found" + exit 0 +fi diff --git a/dependencies.yaml b/dependencies.yaml index 66c33742..bc277925 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -162,7 +162,6 @@ files: table: project.optional-dependencies key: test includes: - - depends_on_pytorch - depends_on_cuml - depends_on_cugraph - depends_on_ogb @@ -180,6 +179,10 @@ files: - depends_on_pyg - depends_on_pytorch - test_python_common + torch_only: + output: none + includes: + - depends_on_pytorch channels: - rapidsai-nightly - rapidsai @@ -348,7 +351,6 @@ dependencies: - *packaging - pytest-forked - scipy - depends_on_pytorch: depends_on_pytorch: specific: # conda: choose between GPU and CPU-only pytorch @@ -383,37 +385,54 @@ dependencies: # Default to falling back to whatever 'pytorch' is pulled in via cugraph-pyg's dependencies. - matrix: packages: - - output_types: [requirements] + # wheels: handle GPU vs. CPU and version pinning together + # + # The 'pytorch.org' indices referenced in --extra-index-url below host CPU-only variants too, + # so requirements like '>=' are not safe. + # + # Using '==' and a version with the CUDA specifier like '+cu130' is the most reliable way to ensure + # the packages we want are pulled (at the expense of needing to maintain this list). + # + # 'torch' tightly pins wheels to a single {major}.{minor} CTK version. + # + # This list only contains entries exactly matching CUDA {major}.{minor} that we test in RAPIDS CI, + # to ensure a loud error alerts us to the need to update this list (or CI scripts) when new + # CTKs are added to the support matrix. 
+ - output_types: requirements matrices: + # avoid pulling in 'torch' in places like DLFW builds that prefer to install it other ways - matrix: no_pytorch: "true" packages: + # matrices below ensure CUDA 'torch' packages are used - matrix: - cuda: "12.*" - packages: - - --extra-index-url=https://download.pytorch.org/whl/cu126 - - matrix: - cuda: "13.*" - packages: - - --extra-index-url=https://download.pytorch.org/whl/cu130 - - matrix: + cuda: "12.9" + dependencies: "oldest" + require_gpu: "true" packages: - - output_types: [requirements, pyproject] - matrices: + - &torch_cu129_index --extra-index-url=https://download.pytorch.org/whl/cu129 + - torch==2.8.0+cu129 - matrix: - no_pytorch: "true" + cuda: "12.9" + require_gpu: "true" packages: + - *torch_cu129_index + - torch==2.10.0+cu129 - matrix: - cuda: "12.*" + cuda: "13.0" + dependencies: "oldest" + require_gpu: "true" packages: - - torch>=2.3 + - &torch_index_cu13 --extra-index-url=https://download.pytorch.org/whl/cu130 + - torch==2.8.0+cu130 - matrix: - cuda: "13.*" + cuda: "13.0" + require_gpu: "true" packages: - - &pytorch_pip torch>=2.9.0 + - *torch_index_cu13 + - torch==2.10.0+cu130 - matrix: packages: - - *pytorch_pip depends_on_nccl: common: - output_types: conda diff --git a/pyproject.toml b/pyproject.toml index fbe24671..5662e9f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 [tool.ruff] @@ -10,5 +10,29 @@ exclude = [ [tool.ruff.lint] ignore = [ # whitespace before : - "E203", + "E203" +] +select = [ + # (pycodestyle) + "E4", + "E7", + "E9", + # (pyflakes) + "F", + # (flake8-tidy-imports) banned-api + "TID251" +] + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"torch".msg = "Use 'import_optional(\"torch\")' in library code, or the 'torch' pytest fixture in test code (see conftest.py), instead of 'import torch'." + +[tool.ruff.lint.per-file-ignores] +# allow importing 'torch' directly in cugraph-pyg examples +"python/cugraph-pyg/cugraph_pyg/examples/*" = [ + "TID251" +] + +# allow importing 'torch' directly in pylibwholegraph examples +"python/pylibwholegraph/examples/*" = [ + "TID251" ] diff --git a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py index ba2081ca..fd645cb7 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 import warnings @@ -18,9 +18,12 @@ wgth = import_optional("pylibwholegraph.torch") +# If 'torch_geometric' is available but 'torch' is not, accessing +# 'torch_geometric.data.GraphStore' will fail because `torch_geometric` +# unconditionally imports 'torch'... so need to check that both are available. 
class FeatureStore( object - if isinstance(torch_geometric, MissingModule) + if (isinstance(torch_geometric, MissingModule) or isinstance(torch, MissingModule)) else torch_geometric.data.FeatureStore ): """ diff --git a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py index eada6a61..7a522912 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py @@ -30,9 +30,12 @@ ] +# If 'torch_geometric' is available but 'torch' is not, accessing +# 'torch_geometric.data.GraphStore' will fail because `torch_geometric` +# unconditionally imports 'torch'... so need to check that both are available. class GraphStore( object - if isinstance(torch_geometric, MissingModule) + if (isinstance(torch_geometric, MissingModule) or isinstance(torch, MissingModule)) else torch_geometric.data.GraphStore ): """ diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py index f8bb1f6e..53644afa 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py @@ -12,10 +12,8 @@ import cupy import pylibcugraph -torch_geometric = import_optional("torch_geometric") - torch = import_optional("torch") -HeteroSamplerOutput = torch_geometric.sampler.base.HeteroSamplerOutput +torch_geometric = import_optional("torch_geometric") def verify_metadata(metadata: Optional[Dict[str, Union[str, Tuple[str, str, str]]]]): diff --git a/python/cugraph-pyg/cugraph_pyg/tensor/dist_matrix.py b/python/cugraph-pyg/cugraph_pyg/tensor/dist_matrix.py index 2c811245..c9560ff4 100644 --- a/python/cugraph-pyg/cugraph_pyg/tensor/dist_matrix.py +++ b/python/cugraph-pyg/cugraph_pyg/tensor/dist_matrix.py @@ -18,14 +18,14 @@ def __init__( self, src: Optional[ Union[ - Tuple[torch.Tensor, torch.Tensor], + Tuple["torch.Tensor", "torch.Tensor"], Tuple[DistTensor, DistTensor], str, 
List[str], ] ] = None, shape: Optional[Union[list, tuple]] = None, - dtype: Optional[torch.dtype] = None, + dtype: Optional["torch.dtype"] = None, device: Optional[Literal["cpu", "cuda"]] = "cpu", backend: Optional[Literal["nccl", "vmm"]] = "nccl", format: Optional[Literal["csc", "coo"]] = "coo", @@ -82,8 +82,8 @@ def __init__( def __setitem__( self, - idx: Union[torch.Tensor, slice], - val: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + idx: Union["torch.Tensor", slice], + val: Union["torch.Tensor", tuple["torch.Tensor", "torch.Tensor"]], ): if isinstance(idx, slice): size = self._col.shape[0] @@ -106,7 +106,7 @@ def __setitem__( self._col[idx] = val[0] self._row[idx] = val[1] - def __getitem__(self, idx: torch.Tensor) -> torch.Tensor: + def __getitem__(self, idx: "torch.Tensor") -> "torch.Tensor": if self._format != "coo": raise ValueError("Getting is currently only supported for COO format") if idx.dim() != 1: @@ -114,11 +114,11 @@ def __getitem__(self, idx: torch.Tensor) -> torch.Tensor: return torch.stack([self._col[idx], self._row[idx]]) - def get_local_tensor(self) -> Tuple[torch.Tensor, torch.Tensor]: + def get_local_tensor(self) -> Tuple["torch.Tensor", "torch.Tensor"]: return (self._col.get_local_tensor(), self._row.get_local_tensor()) @property - def local_col(self) -> torch.Tensor: + def local_col(self) -> "torch.Tensor": world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() @@ -134,7 +134,7 @@ def local_col(self) -> torch.Tensor: return self._col[ix] @property - def local_row(self) -> torch.Tensor: + def local_row(self) -> "torch.Tensor": world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() @@ -150,7 +150,7 @@ def local_row(self) -> torch.Tensor: return self._row[ix] @property - def local_coo(self) -> torch.Tensor: + def local_coo(self) -> "torch.Tensor": return torch.stack([self.local_col, self.local_row]) @property @@ -158,5 +158,5 @@ def shape(self) -> Tuple[int, int]: return 
(self._col.shape[0], self._row.shape[0]) @property - def dtype(self) -> torch.dtype: + def dtype(self) -> "torch.dtype": return self._col.dtype diff --git a/python/cugraph-pyg/cugraph_pyg/tensor/utils.py b/python/cugraph-pyg/cugraph_pyg/tensor/utils.py index d8780000..fb994bc9 100644 --- a/python/cugraph-pyg/cugraph_pyg/tensor/utils.py +++ b/python/cugraph-pyg/cugraph_pyg/tensor/utils.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 from typing import Union, List @@ -96,7 +96,7 @@ def create_wg_dist_tensor( def create_wg_dist_tensor_from_files( file_list: List[str], shape: list, - dtype: torch.dtype, + dtype: "torch.dtype", location: str = "cpu", partition_book: Union[List[int], None] = None, backend: str = "nccl", diff --git a/python/cugraph-pyg/cugraph_pyg/tests/conftest.py b/python/cugraph-pyg/cugraph_pyg/tests/conftest.py index f480aeb8..81864fbd 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/conftest.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/conftest.py @@ -1,9 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pytest import os -import torch from pylibcugraph.comms import ( @@ -29,7 +28,13 @@ @pytest.fixture(scope="module") -def single_pytorch_worker(): +def torch(): + """Pass this to any test case that needs 'torch' to be installed""" + return pytest.importorskip("torch") + + +@pytest.fixture(scope="module") +def single_pytorch_worker(torch): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" os.environ["LOCAL_RANK"] = "0" @@ -44,14 +49,14 @@ def single_pytorch_worker(): @pytest.fixture -def basic_pyg_graph_1(): +def basic_pyg_graph_1(torch): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) size = (4, 4) return edge_index, size @pytest.fixture -def basic_pyg_graph_2(): +def basic_pyg_graph_2(torch): edge_index = torch.tensor( [ [0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9], @@ -63,7 +68,7 @@ def basic_pyg_graph_2(): @pytest.fixture -def sample_pyg_hetero_data(): +def sample_pyg_hetero_data(torch): torch.manual_seed(12345) raw_data_dict = { "v0": torch.randn(6, 3), diff --git a/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py b/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py index f64bee55..fc29c0a8 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pytest @@ -53,18 +53,19 @@ def test_feature_store_basic_api(single_pytorch_worker): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.sg @pytest.mark.parametrize( - "dtype", + "dtype_name", [ - torch.float32, - torch.float16, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.float64, + "float32", + "float16", + "int8", + "int16", + "int32", + "int64", + "float64", ], ) -def test_feature_store_basic_api_types(single_pytorch_worker, dtype): +def test_feature_store_basic_api_types(single_pytorch_worker, dtype_name, torch): + dtype = getattr(torch, dtype_name) features = torch.arange(0, 2000) features = features.reshape((features.numel() // 100, 100)).to(dtype) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/tensor/test_dist_matrix_mg.py b/python/cugraph-pyg/cugraph_pyg/tests/tensor/test_dist_matrix_mg.py index 0ef4ca00..ae2d050e 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/tensor/test_dist_matrix_mg.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/tensor/test_dist_matrix_mg.py @@ -1,10 +1,9 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import os import pytest -import torch from cugraph_pyg.tensor import DistMatrix from pylibwholegraph.torch.initialize import init as wm_init @@ -13,6 +12,7 @@ def run_test_dist_matrix_creation(rank, world_size, device): """Test basic DistMatrix creation from tensors""" + torch = pytest.importorskip("torch") torch.cuda.set_device(rank) os.environ["MASTER_ADDR"] = "localhost" @@ -55,6 +55,7 @@ def run_test_dist_matrix_creation(rank, world_size, device): def run_test_dist_matrix_empty_creation(rank, world_size, device): """Test DistMatrix creation with empty initialization""" + torch = pytest.importorskip("torch") torch.cuda.set_device(rank) os.environ["MASTER_ADDR"] = "localhost" @@ -102,6 +103,7 @@ def run_test_dist_matrix_empty_creation(rank, world_size, device): def run_test_dist_matrix_invalid_cases(rank, world_size, device): """Test DistMatrix creation with invalid cases""" + torch = pytest.importorskip("torch") torch.cuda.set_device(rank) os.environ["MASTER_ADDR"] = "localhost" @@ -138,7 +140,7 @@ def run_test_dist_matrix_invalid_cases(rank, world_size, device): @pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_dist_matrix(device): +def test_dist_matrix(device, torch): """Run all DistMatrix tests""" world_size = torch.cuda.device_count() diff --git a/python/cugraph-pyg/cugraph_pyg/utils/imports.py b/python/cugraph-pyg/cugraph_pyg/utils/imports.py index b4e4df42..270b2eca 100644 --- a/python/cugraph-pyg/cugraph_pyg/utils/imports.py +++ b/python/cugraph-pyg/cugraph_pyg/utils/imports.py @@ -1,8 +1,9 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 from packaging.requirements import Requirement from importlib import import_module +from importlib.util import find_spec def package_available(requirement: str) -> bool: @@ -39,6 +40,18 @@ def __getattr__(self, attr): raise RuntimeError(f"This feature requires the {self.name} package/module") +class FoundModule: + def __init__(self, mod): + self.mod = mod + self.imported = False + + def __getattr__(self, attr): + if not self.imported: + self.mod = import_module(self.mod) + self.imported = True + return getattr(self.mod, attr) + + def import_optional(mod, default_mod_class=MissingModule): """ import the "optional" module 'mod' and return the module object or object. @@ -80,7 +93,15 @@ def import_optional(mod, default_mod_class=MissingModule): >> """ + # this try-except is necessary to handle dotted imports, + # like `import_optional("torch.autograd")` + mod_found = False try: - return import_module(mod) - except ModuleNotFoundError: + mod_found = find_spec(mod) is not None + except ImportError: + mod_found = False + + if mod_found: + return FoundModule(mod) + else: return default_mod_class(mod_name=mod) diff --git a/python/cugraph-pyg/pyproject.toml b/python/cugraph-pyg/pyproject.toml index 5b102513..1dabec2e 100644 --- a/python/cugraph-pyg/pyproject.toml +++ b/python/cugraph-pyg/pyproject.toml @@ -58,7 +58,6 @@ test = [ "pytest-cov", "pytest-xdist", "sentence-transformers", - "torch>=2.9.0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
[tool.setuptools.dynamic] diff --git a/python/pylibwholegraph/pylibwholegraph/_doctor_check.py b/python/pylibwholegraph/pylibwholegraph/_doctor_check.py index 33ac107d..a76e8483 100644 --- a/python/pylibwholegraph/pylibwholegraph/_doctor_check.py +++ b/python/pylibwholegraph/pylibwholegraph/_doctor_check.py @@ -27,7 +27,7 @@ def pylibwholegraph_smoke_check(**kwargs): ) try: - import torch + import torch # noqa: TID251 assert torch.cuda.is_available() diff --git a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py index 3197b53f..bbfd2163 100644 --- a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py +++ b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py @@ -1,8 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -import torch import numpy as np +import pytest import pylibwholegraph.binding.wholememory_binding as wmb from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack from packaging import version @@ -17,6 +17,7 @@ def gen_csr_format_from_dense_matrix( csr_col_dtype, weight_dtype, ): + torch = pytest.importorskip("torch") row_num = matrix_tensor.shape[0] col_num = matrix_tensor.shape[1] assert row_num == graph_node_count @@ -44,11 +45,13 @@ def gen_csr_format_from_dense_matrix( def gen_csr_graph( graph_node_count, graph_edge_count, - neighbor_node_count=None, - csr_row_dtype=torch.int64, - csr_col_dtype=torch.int32, - weight_dtype=torch.float32, + *, + neighbor_node_count, + csr_row_dtype, + csr_col_dtype, + weight_dtype, ): + torch = pytest.importorskip("torch") if neighbor_node_count is None: neighbor_node_count = graph_node_count all_count = graph_node_count * neighbor_node_count @@ -95,6 +98,7 @@ def host_sample_all_neighbors( col_id_dtype, total_sample_count, ): + torch = pytest.importorskip("torch") 
output_dest_tensor = torch.empty((total_sample_count,), dtype=col_id_dtype) output_center_localid_tensor = torch.empty((total_sample_count,), dtype=torch.int32) output_edge_gid_tensor = torch.empty((total_sample_count,), dtype=torch.int64) @@ -133,6 +137,7 @@ def copy_host_1D_tensor_to_wholememory( def host_get_sample_offset_tensor(host_csr_row_ptr, center_nodes, max_sample_count): + torch = pytest.importorskip("torch") center_nodes_count = center_nodes.size(0) output_sample_offset_tensor = torch.empty( (center_nodes_count + 1,), dtype=torch.int32 diff --git a/python/pylibwholegraph/pylibwholegraph/tests/conftest.py b/python/pylibwholegraph/pylibwholegraph/tests/conftest.py new file mode 100644 index 00000000..4032ca07 --- /dev/null +++ b/python/pylibwholegraph/pylibwholegraph/tests/conftest.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + + +@pytest.fixture(scope="function") +def torch(): + """Pass this to any test case that needs 'torch' to be installed""" + return pytest.importorskip("torch") diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_binding.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_binding.py index 7e11b731..366d03e3 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_binding.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_binding.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pytest @@ -6,7 +6,6 @@ from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack -import torch # Run with: @@ -14,6 +13,7 @@ def single_test_case(wm_comm, mt, ml, malloc_size, granularity): + torch = pytest.importorskip("torch") world_rank = wm_comm.get_rank() print("Rank=%d testing mt=%s, ml=%s" % (world_rank, mt, ml)) h = wmb.malloc(malloc_size, wm_comm, mt, ml, granularity) @@ -105,7 +105,7 @@ def routine_func(world_rank: int, world_size: int): wmb.finalize() -def test_dlpack(): +def test_dlpack(torch): gpu_count = wmb.fork_get_gpu_count() assert gpu_count > 0 multiprocess_run(gpu_count, routine_func) diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py index c9419c75..bf093dca 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pytest @@ -7,13 +7,11 @@ from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack from pylibwholegraph.test_utils.test_comm import random_partition -import torch import numpy as np import os import random from functools import partial - gpu_count = None @@ -49,6 +47,7 @@ def load_routine_func( round_robin_size=0, entry_partition=None, ): + torch = pytest.importorskip("torch") wm_comm, _ = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size ) @@ -181,6 +180,7 @@ def test_wholememory_load( storage_offset, round_robin_size, partition_method, + torch, ): if embedding_stride < storage_offset + embedding_dim: pytest.skip( @@ -294,6 +294,7 @@ def store_routine_func( storage_offset, entry_partition, ): + torch = pytest.importorskip("torch") (wm_comm, _) = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size ) @@ -347,6 +348,7 @@ def test_wholememory_store( embedding_stride, storage_offset, partition_method, + torch, ): if embedding_stride < storage_offset + embedding_dim: pytest.skip( diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py index 648f7dc8..0e53c209 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pylibwholegraph.binding.wholememory_binding as wmb @@ -107,7 +107,7 @@ def routine_func(world_rank: int, world_size: int): wmb.finalize() -def test_wholememory_tensor(): +def test_wholememory_tensor(torch): gpu_count = wmb.fork_get_gpu_count() assert gpu_count > 0 multiprocess_run(gpu_count, routine_func) diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_add_csr_self_loop.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_add_csr_self_loop.py index 821cf457..07fb409a 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_add_csr_self_loop.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_add_csr_self_loop.py @@ -1,13 +1,13 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 import pytest -import torch from pylibwholegraph.test_utils.test_comm import gen_csr_graph import pylibwholegraph.torch.graph_ops as wg_ops def host_add_csr_self_loop(csr_row_ptr_tensor, csr_col_ptr_tensor): + torch = pytest.importorskip("torch") row_num = csr_row_ptr_tensor.shape[0] - 1 edge_num = csr_col_ptr_tensor.shape[0] output_csr_row_ptr_tensor = torch.empty( @@ -28,6 +28,7 @@ def host_add_csr_self_loop(csr_row_ptr_tensor, csr_col_ptr_tensor): def routine_func(**kwargs): + torch = pytest.importorskip("torch") target_node_count = kwargs["target_node_count"] neighbor_node_count = kwargs["neighbor_node_count"] edge_num = kwargs["edge_num"] @@ -35,9 +36,10 @@ def routine_func(**kwargs): csr_row_ptr_tensor, csr_col_ptr_tensor, _ = gen_csr_graph( target_node_count, edge_num, - neighbor_node_count, + neighbor_node_count=neighbor_node_count, csr_row_dtype=torch.int32, csr_col_dtype=torch.int32, + weight_dtype=torch.float32, ) csr_row_ptr_tensor_cuda = csr_row_ptr_tensor.cuda() 
csr_col_ptr_tensor_cuda = csr_col_ptr_tensor.cuda() @@ -58,7 +60,7 @@ def routine_func(**kwargs): @pytest.mark.parametrize("target_node_count", [101, 113]) @pytest.mark.parametrize("neighbor_node_count", [157, 1987]) @pytest.mark.parametrize("edge_num", [1001, 2305]) -def test_add_csr_self_loop(target_node_count, neighbor_node_count, edge_num): +def test_add_csr_self_loop(target_node_count, neighbor_node_count, edge_num, torch): gpu_count = torch.cuda.device_count() assert gpu_count > 0 routine_func( diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_append_unique.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_append_unique.py index e325ef51..e94c1a9a 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_append_unique.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_append_unique.py @@ -1,12 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pytest -import torch import pylibwholegraph.torch.graph_ops as wg_ops def host_neighbor_raw_to_unique(unique_node_tensor, neighbor_node_tensor): + torch = pytest.importorskip("torch") output_neighbor_raw_to_unique = torch.empty( (neighbor_node_tensor.size(0)), dtype=torch.int32 ) @@ -19,6 +19,7 @@ def host_neighbor_raw_to_unique(unique_node_tensor, neighbor_node_tensor): def routine_func(**kwargs): + torch = pytest.importorskip("torch") target_node_count = kwargs["target_node_count"] neighbor_node_count = kwargs["neighbor_node_count"] target_node_dtype = kwargs["target_node_dtype"] @@ -73,19 +74,20 @@ def routine_func(**kwargs): @pytest.mark.parametrize("target_node_count", [10, 113]) @pytest.mark.parametrize("neighbor_node_count", [104, 1987]) -@pytest.mark.parametrize("target_node_dtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("target_node_dtype", ["int32", "int64"]) @pytest.mark.parametrize("need_neighbor_raw_to_unique", [True, False]) def test_append_unique( target_node_count, neighbor_node_count, target_node_dtype, need_neighbor_raw_to_unique, + torch, ): gpu_count = torch.cuda.device_count() assert gpu_count > 0 routine_func( target_node_count=target_node_count, neighbor_node_count=neighbor_node_count, - target_node_dtype=target_node_dtype, + target_node_dtype=getattr(torch, target_node_dtype), need_neighbor_raw_to_unique=need_neighbor_raw_to_unique, ) diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py index a3b8849b..0395d2a6 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py @@ -1,20 +1,20 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. 
+# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 +import pytest import pylibwholegraph.binding.wholememory_binding as wmb from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack from pylibwholegraph.test_utils.test_comm import random_partition -import torch import pylibwholegraph.torch.wholememory_ops as wm_ops - # PYTHONPATH=../:$PYTHONPATH python3 -m pytest \ # ../tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py -s def gen_int_embedding(indice_tensor, embedding_dim, output_type): + torch = pytest.importorskip("torch") if embedding_dim == 0: embedding_dim = 1 # unsqueeze 2D for input (2D is required for scatter op) indice_count = indice_tensor.shape[0] @@ -41,6 +41,7 @@ def scatter_gather_test_cast( use_python_binding=True, entry_partition=None, ): + torch = pytest.importorskip("torch") world_rank = wm_comm.get_rank() world_size = wm_comm.get_size() print( @@ -173,7 +174,7 @@ def routine_func(world_rank: int, world_size: int): wmb.finalize() -def test_wholegraph_gather_scatter(): +def test_wholegraph_gather_scatter(torch): gpu_count = wmb.fork_get_gpu_count() assert gpu_count > 0 multiprocess_run(gpu_count, routine_func) diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py index c436e9d1..75f1cd9a 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py @@ -1,11 +1,10 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA 
CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 import pytest import pylibwholegraph.binding.wholememory_binding as wmb from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm -import torch from functools import partial from pylibwholegraph.test_utils.test_comm import ( gen_csr_graph, @@ -21,6 +20,7 @@ def unweighte_sample_without_replacement_base(random_values, M, N): + torch = pytest.importorskip("torch") a = torch.empty((M,), dtype=torch.int32) Q = torch.arange(N, dtype=torch.int32) for i in range(M): @@ -39,6 +39,7 @@ def host_unweighted_sample_without_replacement_func( max_sample_count, random_seed, ): + torch = pytest.importorskip("torch") output_dest_tensor = torch.empty((total_sample_count,), dtype=col_id_dtype) output_center_localid_tensor = torch.empty((total_sample_count,), dtype=torch.int32) output_edge_gid_tensor = torch.empty((total_sample_count,), dtype=torch.int64) @@ -211,6 +212,7 @@ def host_unweighted_sample_without_replacement( def routine_func(world_rank: int, world_size: int, **kwargs): + torch = pytest.importorskip("torch") wm_comm, _ = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size ) @@ -351,7 +353,7 @@ def routine_func(world_rank: int, world_size: int, **kwargs): @pytest.mark.parametrize("graph_edge_count", [1043]) @pytest.mark.parametrize("max_sample_count", [11, -1]) @pytest.mark.parametrize("center_node_count", [13]) -@pytest.mark.parametrize("center_node_dtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("center_node_dtype", ["int32", "int64"]) @pytest.mark.parametrize("col_id_dtype", [0, 1]) @pytest.mark.parametrize("wholememory_location", ([0, 1])) @pytest.mark.parametrize("wholememory_type", ([0, 1, 2])) @@ -368,6 +370,7 @@ def test_wholegraph_unweighted_sample( wholememory_type, need_center_local_output, need_edge_output, + 
torch, ): gpu_count = wmb.fork_get_gpu_count() assert gpu_count > 0 @@ -375,7 +378,12 @@ def test_wholegraph_unweighted_sample( if col_id_dtype == wmb.WholeMemoryDataType.DtInt64: csr_col_dtype = torch.int64 host_csr_row_ptr, host_csr_col_ptr, _ = gen_csr_graph( - graph_node_count, graph_edge_count, csr_col_dtype=csr_col_dtype + graph_node_count, + graph_edge_count, + neighbor_node_count=None, + csr_row_dtype=torch.int64, + csr_col_dtype=csr_col_dtype, + weight_dtype=torch.float32, ) routine_func_partial = partial( routine_func, @@ -385,7 +393,7 @@ def test_wholegraph_unweighted_sample( graph_edge_count=graph_edge_count, max_sample_count=max_sample_count, center_node_count=center_node_count, - center_node_dtype=center_node_dtype, + center_node_dtype=getattr(torch, center_node_dtype), col_id_dtype=col_id_dtype, wholememory_location=wholememory_location, wholememory_type=wholememory_type, diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py index 10ef139e..7e473f60 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py @@ -1,11 +1,10 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pytest from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm import pylibwholegraph.binding.wholememory_binding as wmb -import torch import random from functools import partial from pylibwholegraph.test_utils.test_comm import ( @@ -32,6 +31,7 @@ def host_weighted_sample_without_replacement_func( max_sample_count, random_seed, ): + torch = pytest.importorskip("torch") output_dest_tensor = torch.empty((total_sample_count,), dtype=col_id_dtype) output_center_localid_tensor = torch.empty((total_sample_count,), dtype=torch.int32) output_edge_gid_tensor = torch.empty((total_sample_count,), dtype=torch.int64) @@ -116,6 +116,7 @@ def host_weighted_sample_without_replacement( col_id_dtype, random_seed, ): + torch = pytest.importorskip("torch") center_nodes_count = center_nodes.size(0) output_sample_offset_tensor = host_get_sample_offset_tensor( host_csr_row_ptr, center_nodes, max_sample_count @@ -166,6 +167,7 @@ def host_weighted_sample_without_replacement( def routine_func(world_rank: int, world_size: int, **kwargs): + torch = pytest.importorskip("torch") wm_comm, _ = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size ) @@ -353,7 +355,7 @@ def routine_func(world_rank: int, world_size: int, **kwargs): @pytest.mark.parametrize("graph_edge_count", [1043]) @pytest.mark.parametrize("max_sample_count", [11]) @pytest.mark.parametrize("center_node_count", [13]) -@pytest.mark.parametrize("center_node_dtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("center_node_dtype", ["int32", "int64"]) @pytest.mark.parametrize("col_id_dtype", [0, 1]) @pytest.mark.parametrize("csr_weight_dtype", [2, 3]) @pytest.mark.parametrize("wholememory_location", ([0, 1])) @@ -372,6 +374,7 @@ def test_wholegraph_weighted_sample( wholememory_type, need_center_local_output, need_edge_output, + torch, ): gpu_count = 
wmb.fork_get_gpu_count() assert gpu_count > 0 @@ -379,7 +382,12 @@ def test_wholegraph_weighted_sample( if col_id_dtype == 1: csr_col_dtype = torch.int64 host_csr_row_ptr, host_csr_col_ptr, host_csr_weight_ptr = gen_csr_graph( - graph_node_count, graph_edge_count, csr_col_dtype=csr_col_dtype + graph_node_count, + graph_edge_count, + neighbor_node_count=None, + csr_row_dtype=torch.int64, + csr_col_dtype=csr_col_dtype, + weight_dtype=torch.float32, ) routine_func_partial = partial( routine_func, @@ -390,7 +398,7 @@ def test_wholegraph_weighted_sample( graph_edge_count=graph_edge_count, max_sample_count=max_sample_count, center_node_count=center_node_count, - center_node_dtype=center_node_dtype, + center_node_dtype=getattr(torch, center_node_dtype), col_id_dtype=col_id_dtype, csr_weight_dtype=csr_weight_dtype, wholememory_location=wholememory_location, diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholememory_cython_binding.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholememory_cython_binding.py index a4726383..0e3310c2 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholememory_cython_binding.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholememory_cython_binding.py @@ -1,9 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pytest import pylibwholegraph.binding.wholememory_binding as wmb -import torch from pylibwholegraph.torch.wholegraph_env import ( get_stream, get_wholegraph_env_fns, @@ -14,7 +13,7 @@ import time -def test_smoke(): +def test_smoke(torch): torch.cuda.set_device(0) output_len = 128 embed_dim = 10 @@ -55,7 +54,7 @@ def test_smoke(): assert wmb.py_get_wholememory_tensor_count() == 0 -def test_loop_memory(): +def test_loop_memory(torch): torch.cuda.set_device(0) embedding_dim = 1 output_len = 1 @@ -107,7 +106,7 @@ def test_loop_memory(): @pytest.mark.parametrize("output_len", list(range(1, 100, 17))) @pytest.mark.parametrize("embed_dim", list(range(1, 128, 23))) -def test_random_alloc(output_len, embed_dim): +def test_random_alloc(output_len, embed_dim, torch): torch.cuda.set_device(0) input_tensor = torch.rand((embed_dim,), device="cuda") indice_tensor = torch.arange(output_len, device="cuda") diff --git a/python/pylibwholegraph/pylibwholegraph/torch/comm.py b/python/pylibwholegraph/pylibwholegraph/torch/comm.py index 634473f7..85be715a 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/comm.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/comm.py @@ -1,10 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 -import torch -import torch.distributed as dist -import torch.utils.dlpack import pylibwholegraph.binding.wholememory_binding as wmb +from pylibwholegraph.utils.imports import import_optional from .utils import ( str_to_wmb_wholememory_distributed_backend_type, wholememory_distributed_backend_type_to_str, @@ -12,6 +10,8 @@ str_to_wmb_wholememory_location, ) +torch = import_optional("torch") + global_communicators = {} local_node_communicator = None local_device_communicator = None @@ -140,13 +140,13 @@ def create_group_communicator(group_size: int = -1, comm_stride: int = 1): :param comm_stride: Stride of each rank in each group :return: WholeMemoryCommunicator """ - world_size = dist.get_world_size() + world_size = torch.distributed.get_world_size() if group_size == -1: group_size = world_size strided_group_size = group_size * comm_stride assert world_size % strided_group_size == 0 strided_group_count = world_size // strided_group_size - world_rank = dist.get_rank() + world_rank = torch.distributed.get_rank() strided_group_idx = world_rank // strided_group_size idx_in_strided_group = world_rank % strided_group_size inner_group_idx = idx_in_strided_group % comm_stride @@ -161,7 +161,7 @@ def create_group_communicator(group_size: int = -1, comm_stride: int = 1): tmp_wm_uid = wmb.PyWholeMemoryUniqueID() uid_th = torch.utils.dlpack.from_dlpack(tmp_wm_uid.__dlpack__()) uid_th_cuda = uid_th.cuda() - dist.broadcast(uid_th_cuda, group_root_rank) + torch.distributed.broadcast(uid_th_cuda, group_root_rank) uid_th.copy_(uid_th_cuda.cpu()) if strided_group_idx == strided_group and inner_group_idx == inner_group: wm_uid_th = torch.utils.dlpack.from_dlpack(wm_uid.__dlpack__()) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/data_loader.py b/python/pylibwholegraph/pylibwholegraph/torch/data_loader.py index 35dd8e60..041c2d77 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/data_loader.py +++ 
b/python/pylibwholegraph/pylibwholegraph/torch/data_loader.py @@ -1,20 +1,36 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 import numpy as np -import torch -from torch.utils.data import Dataset +from pylibwholegraph.utils.imports import import_optional, MissingModule +torch = import_optional("torch") -class NodeClassificationDataset(Dataset): - def __init__(self, raw_dataset): - self.dataset = raw_dataset +# NOTE: using more specific 'import_optional()' than just 'torch' for import-time checks +# (e.g. those needed for defining base classes) can be helpful because 'torch' can appear +# to be available even after a 'pip uninstall torch' if any files are left behind in +# 'site-packages/torch'. +torch_utils_data = import_optional("torch.utils.data") - def __getitem__(self, index): - return self.dataset[index] - def __len__(self): - return len(self.dataset) +if not isinstance(torch_utils_data, MissingModule): + + class NodeClassificationDataset(torch_utils_data.Dataset): + def __init__(self, raw_dataset): + self.dataset = raw_dataset + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) +else: + + class NodeClassificationDataset: + def __init__(self, raw_dataset): + raise ModuleNotFoundError( + "NodeClassificationDataset requires 'torch.utils.data'. Install 'torch'." 
+ ) def create_node_classification_datasets(data_and_label: dict): @@ -55,14 +71,14 @@ def get_train_dataloader( num_replicas: int = 1, num_workers: int = 0, ): - train_sampler = torch.utils.data.distributed.DistributedSampler( + train_sampler = torch_utils_data.distributed.DistributedSampler( train_dataset, num_replicas=num_replicas, rank=replica_id, shuffle=True, drop_last=True, ) - train_dataloader = torch.utils.data.DataLoader( + train_dataloader = torch_utils_data.DataLoader( train_dataset, batch_size=batch_size, num_workers=num_workers, @@ -76,10 +92,10 @@ def get_train_dataloader( def get_valid_test_dataloader( valid_test_dataset, batch_size: int, *, num_workers: int = 0 ): - valid_test_sampler = torch.utils.data.distributed.DistributedSampler( + valid_test_sampler = torch_utils_data.distributed.DistributedSampler( valid_test_dataset, num_replicas=1, rank=0, shuffle=False, drop_last=False ) - valid_test_dataloader = torch.utils.data.DataLoader( + valid_test_dataloader = torch_utils_data.DataLoader( valid_test_dataset, batch_size=batch_size, num_workers=num_workers, diff --git a/python/pylibwholegraph/pylibwholegraph/torch/distributed_launch.py b/python/pylibwholegraph/pylibwholegraph/torch/distributed_launch.py index e7990546..50dcaae0 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/distributed_launch.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/distributed_launch.py @@ -1,9 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 +from pylibwholegraph.utils.imports import import_optional import os from argparse import ArgumentParser +torch = import_optional("torch") + class DistributedConfig(object): def __init__(self): @@ -281,10 +284,8 @@ def distributed_launch_spawn(args, main_func): ) ) - import torch.multiprocessing as mp - if distributed_config.local_size > 1: - mp.spawn( + torch.multiprocessing.spawn( main_spawn_routine, nprocs=distributed_config.local_size, args=(main_func, distributed_config), diff --git a/python/pylibwholegraph/pylibwholegraph/torch/dlpack_utils.py b/python/pylibwholegraph/pylibwholegraph/torch/dlpack_utils.py index 25f36bf3..74e1ae82 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/dlpack_utils.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/dlpack_utils.py @@ -1,8 +1,9 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -import torch -import torch.utils.dlpack +from pylibwholegraph.utils.imports import import_optional + +torch = import_optional("torch") def torch_import_from_dlpack(dp): diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py index aad0a552..b89ebe93 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py @@ -1,8 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import pylibwholegraph.binding.wholememory_binding as wmb -import torch +from pylibwholegraph.utils.imports import MissingModule, import_optional from .utils import torch_dtype_to_wholememory_dtype, get_file_size from .utils import str_to_wmb_wholememory_location, str_to_wmb_wholememory_memory_type from .utils import ( @@ -19,6 +19,15 @@ from .tensor import WholeMemoryTensor from .wholegraph_env import wrap_torch_tensor, get_wholegraph_env_fns, get_stream +torch = import_optional("torch") + +# NOTE: using more specific 'import_optional()' than just 'torch' for import-time checks +# (e.g. those needed for defining base classes) can be helpful because 'torch' can appear +# to be available even after a 'pip uninstall torch' if any files are left behind in +# 'site-packages/torch'. +torch_autograd = import_optional("torch.autograd") +torch_nn = import_optional("torch.nn") + class WholeMemoryOptimizer(object): """ @@ -207,31 +216,60 @@ def create_builtin_cache_policy( ) -class EmbeddingLookupFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - indice: torch.Tensor, - dummy_input: torch.Tensor, - wm_embedding, - is_training: bool = False, - force_dtype: Union[torch.dtype, None] = None, - ): - output_tensor = wm_embedding.gather( - indice, is_training=is_training, force_dtype=force_dtype - ) - if is_training and wm_embedding.need_grad(): - ctx.save_for_backward(indice, output_tensor, dummy_input) - ctx.wm_embedding = wm_embedding - return output_tensor +if not isinstance(torch_autograd, MissingModule): + + class EmbeddingLookupFn(torch_autograd.Function): + @staticmethod + def forward( + ctx, + indice: "torch.Tensor", + dummy_input: "torch.Tensor", + wm_embedding, + is_training: bool = False, + force_dtype: Union["torch.dtype", None] = None, + ): + output_tensor = wm_embedding.gather( + indice, is_training=is_training, force_dtype=force_dtype + ) + if is_training and wm_embedding.need_grad(): + 
ctx.save_for_backward(indice, output_tensor, dummy_input) + ctx.wm_embedding = wm_embedding + return output_tensor + + @staticmethod + def backward(ctx, grad_outputs: "torch.Tensor"): + indice, output_tensor, dummy_input = ctx.saved_tensors + wm_embedding = ctx.wm_embedding + wm_embedding.add_gradients(indice, grad_outputs) + ctx.wm_embedding = None + return None, torch.zeros_like(dummy_input), None, None, None + +else: + + class EmbeddingLookupFn: + def __init__(self, *args, **kwargs): + raise ModuleNotFoundError( + "EmbeddingLookupFn requires 'torch.autograd'. Install 'torch'." + ) - @staticmethod - def backward(ctx, grad_outputs: torch.Tensor): - indice, output_tensor, dummy_input = ctx.saved_tensors - wm_embedding = ctx.wm_embedding - wm_embedding.add_gradients(indice, grad_outputs) - ctx.wm_embedding = None - return None, torch.zeros_like(dummy_input), None, None, None + @staticmethod + def forward( + ctx, + indice: "torch.Tensor", + dummy_input: "torch.Tensor", + wm_embedding, + is_training: bool = False, + force_dtype: Union["torch.dtype", None] = None, + ): + raise ModuleNotFoundError( + "EmbeddingLookupFn requires 'torch.autograd'. Install 'torch'." + ) + + @staticmethod + def backward(ctx, grad_outputs: "torch.Tensor"): + raise ModuleNotFoundError( + "EmbeddingLookupFn requires 'torch.autograd'. Install 'torch'." 
+ ) class WholeMemoryEmbedding(object): @@ -253,7 +291,7 @@ def __init__( self.wmb_optimizer = None - self.dummy_input = torch.nn.Parameter(torch.zeros(1), requires_grad=False) + self.dummy_input = torch_nn.Parameter(torch.zeros(1), requires_grad=False) self.need_apply = False self.sparse_indices = [] self.sparse_grads = [] @@ -273,10 +311,10 @@ def need_grad(self): def gather( self, - indice: torch.Tensor, + indice: "torch.Tensor", *, is_training: bool = False, - force_dtype: Union[torch.dtype, None] = None, + force_dtype: Union["torch.dtype", None] = None, ): assert indice.dim() == 1 embedding_dim = self.get_embedding_tensor().shape[1] @@ -304,7 +342,7 @@ def gather( ) return output_tensor - def add_gradients(self, indice: torch.Tensor, grad_outputs: torch.Tensor): + def add_gradients(self, indice: "torch.Tensor", grad_outputs: "torch.Tensor"): self.sparse_indices.append(indice) self.sparse_grads.append(grad_outputs) @@ -373,7 +411,7 @@ def create_embedding( comm: WholeMemoryCommunicator, memory_type: str, memory_location: str, - dtype: torch.dtype, + dtype: "torch.dtype", sizes: List[int], *, cache_policy: Union[WholeMemoryCachePolicy, None] = None, @@ -452,7 +490,7 @@ def create_embedding( local_tensor, local_offset, ) = wm_embedding.get_embedding_tensor().get_local_tensor() - torch.nn.init.xavier_uniform_(local_tensor) + torch_nn.init.xavier_uniform_(local_tensor) comm.barrier() return wm_embedding @@ -462,7 +500,7 @@ def create_embedding_from_filelist( memory_type: str, memory_location: str, filelist: Union[List[str], str], - dtype: torch.dtype, + dtype: "torch.dtype", last_dim_size: int, *, cache_policy: Union[WholeMemoryCachePolicy, None] = None, @@ -536,26 +574,35 @@ def destroy_embedding(wm_embedding: WholeMemoryEmbedding): wm_embedding.wmb_embedding = None -class WholeMemoryEmbeddingModule(torch.nn.Module): - """ - torch.nn.Module wrapper of WholeMemoryEmbedding - """ +if not isinstance(torch_nn, MissingModule): - def __init__(self, wm_embedding: 
WholeMemoryEmbedding): - super().__init__() - self.wm_embedding = wm_embedding - self.embedding_gather_fn = EmbeddingLookupFn.apply + class WholeMemoryEmbeddingModule(torch_nn.Module): + """ + torch.nn.Module wrapper of WholeMemoryEmbedding + """ - def forward( - self, indice: torch.Tensor, force_dtype: Union[torch.dtype, None] = None - ): - return self.embedding_gather_fn( - indice, - self.wm_embedding.dummy_input, - self.wm_embedding, - self.training, - force_dtype, - ) + def __init__(self, wm_embedding: WholeMemoryEmbedding): + super().__init__() + self.wm_embedding = wm_embedding + self.embedding_gather_fn = EmbeddingLookupFn.apply + + def forward( + self, indice: "torch.Tensor", force_dtype: Union["torch.dtype", None] = None + ): + return self.embedding_gather_fn( + indice, + self.wm_embedding.dummy_input, + self.wm_embedding, + self.training, + force_dtype, + ) +else: + + class WholeMemoryEmbeddingModule: + def __init__(self, wm_embedding: WholeMemoryEmbedding): + raise ModuleNotFoundError( + "WholeMemoryEmbeddingModule requires 'torch.nn.Module'. Install 'torch'." + ) def create_wholememory_optimizer( diff --git a/python/pylibwholegraph/pylibwholegraph/torch/gnn_model.py b/python/pylibwholegraph/pylibwholegraph/torch/gnn_model.py index b779862c..c6e2813a 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/gnn_model.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/gnn_model.py @@ -1,12 +1,18 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 -import torch +from pylibwholegraph.utils.imports import import_optional, MissingModule from .graph_structure import GraphStructure from .embedding import WholeMemoryEmbedding, WholeMemoryEmbeddingModule from .common_options import parse_max_neighbors -import torch.nn.functional as F +torch = import_optional("torch") + +# NOTE: using more specific 'import_optional()' than just 'torch' for import-time checks +# (e.g. those needed for defining base classes) can be helpful because 'torch' can appear +# to be available even after a 'pip uninstall torch' if any files are left behind in +# 'site-packages/torch'. +torch_nn = import_optional("torch.nn") framework_name = None @@ -28,7 +34,7 @@ def set_framework(framework: str): def create_gnn_layers( in_feat_dim, hidden_feat_dim, class_count, num_layer, num_head, model_type ): - gnn_layers = torch.nn.ModuleList() + gnn_layers = torch_nn.ModuleList() global framework_name for i in range(num_layer): layer_output_dim = ( @@ -119,74 +125,88 @@ def layer_forward(layer, x_feat, x_target_feat, sub_graph): return x_feat -class HomoGNNModel(torch.nn.Module): - def __init__( - self, - graph_structure: GraphStructure, - node_embedding: WholeMemoryEmbedding, - args, - ): - super().__init__() - hidden_feat_dim = args.hiddensize - self.graph_structure = graph_structure - self.node_embedding = node_embedding - self.num_layer = args.layernum - self.hidden_feat_dim = args.hiddensize - num_head = args.heads if (args.model == "gat") else 1 - assert hidden_feat_dim % num_head == 0 - in_feat_dim = self.node_embedding.shape[1] - self.gnn_layers = create_gnn_layers( - in_feat_dim, - hidden_feat_dim, - args.classnum, - args.layernum, - num_head, - args.model, - ) - self.mean_output = True if args.model == "gat" else False - self.add_self_loop = True if args.model == "gat" else False - self.gather_fn = WholeMemoryEmbeddingModule(self.node_embedding) - self.dropout = args.dropout - self.max_neighbors = 
parse_max_neighbors(args.layernum, args.neighbors) - self.max_inference_neighbors = parse_max_neighbors( - args.layernum, args.inferencesample - ) +if not isinstance(torch_nn, MissingModule): - def forward(self, ids): - global framework_name - max_neighbors = ( - self.max_neighbors if self.training else self.max_inference_neighbors - ) - ids = ids.to(self.graph_structure.csr_col_ind.dtype).cuda() - ( - target_gids, - edge_indice, - csr_row_ptrs, - csr_col_inds, - ) = self.graph_structure.multilayer_sample_without_replacement( - ids, max_neighbors - ) - x_feat = self.gather_fn(target_gids[0], force_dtype=torch.float32) - for i in range(self.num_layer): - x_target_feat = x_feat[: target_gids[i + 1].numel()] - sub_graph = create_sub_graph( - target_gids[i], - target_gids[i + 1], - edge_indice[i], - csr_row_ptrs[i], - csr_col_inds[i], - max_neighbors[self.num_layer - 1 - i], - self.add_self_loop, + class HomoGNNModel(torch_nn.Module): + def __init__( + self, + graph_structure: GraphStructure, + node_embedding: WholeMemoryEmbedding, + args, + ): + super().__init__() + hidden_feat_dim = args.hiddensize + self.graph_structure = graph_structure + self.node_embedding = node_embedding + self.num_layer = args.layernum + self.hidden_feat_dim = args.hiddensize + num_head = args.heads if (args.model == "gat") else 1 + assert hidden_feat_dim % num_head == 0 + in_feat_dim = self.node_embedding.shape[1] + self.gnn_layers = create_gnn_layers( + in_feat_dim, + hidden_feat_dim, + args.classnum, + args.layernum, + num_head, + args.model, + ) + self.mean_output = True if args.model == "gat" else False + self.add_self_loop = True if args.model == "gat" else False + self.gather_fn = WholeMemoryEmbeddingModule(self.node_embedding) + self.dropout = args.dropout + self.max_neighbors = parse_max_neighbors(args.layernum, args.neighbors) + self.max_inference_neighbors = parse_max_neighbors( + args.layernum, args.inferencesample ) - x_feat = layer_forward( - self.gnn_layers[i], - x_feat, - 
x_target_feat, - sub_graph, + + def forward(self, ids): + global framework_name + max_neighbors = ( + self.max_neighbors if self.training else self.max_inference_neighbors + ) + ids = ids.to(self.graph_structure.csr_col_ind.dtype).cuda() + ( + target_gids, + edge_indice, + csr_row_ptrs, + csr_col_inds, + ) = self.graph_structure.multilayer_sample_without_replacement( + ids, max_neighbors ) - if i != self.num_layer - 1: - x_feat = F.relu(x_feat) - x_feat = F.dropout(x_feat, self.dropout, training=self.training) + x_feat = self.gather_fn(target_gids[0], force_dtype=torch.float32) + for i in range(self.num_layer): + x_target_feat = x_feat[: target_gids[i + 1].numel()] + sub_graph = create_sub_graph( + target_gids[i], + target_gids[i + 1], + edge_indice[i], + csr_row_ptrs[i], + csr_col_inds[i], + max_neighbors[self.num_layer - 1 - i], + self.add_self_loop, + ) + x_feat = layer_forward( + self.gnn_layers[i], + x_feat, + x_target_feat, + sub_graph, + ) + if i != self.num_layer - 1: + x_feat = torch_nn.functional.relu(x_feat) + x_feat = torch_nn.functional.dropout( + x_feat, self.dropout, training=self.training + ) + + out_feat = x_feat + return out_feat +else: - out_feat = x_feat - return out_feat + class HomoGNNModel: + def __init__( + self, + graph_structure: GraphStructure, + node_embedding: WholeMemoryEmbedding, + args, + ): + raise ModuleNotFoundError("HomoGNNModel requires 'torch' to be installed.") diff --git a/python/pylibwholegraph/pylibwholegraph/torch/graph_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/graph_ops.py index ae352444..c2bec6fe 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/graph_ops.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/graph_ops.py @@ -1,7 +1,6 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 - -import torch +from pylibwholegraph.utils.imports import import_optional import pylibwholegraph.binding.wholememory_binding as wmb from .wholegraph_env import ( get_stream, @@ -10,10 +9,12 @@ wrap_torch_tensor, ) +torch = import_optional("torch") + def append_unique( - target_node_tensor: torch.Tensor, - neighbor_node_tensor: torch.Tensor, + target_node_tensor: "torch.Tensor", + neighbor_node_tensor: "torch.Tensor", need_neighbor_raw_to_unique: bool = False, ): """ @@ -60,7 +61,8 @@ def append_unique( def add_csr_self_loop( - csr_row_ptr_tensor: torch.Tensor, csr_col_ptr_tensor: torch.Tensor + csr_row_ptr_tensor: "torch.Tensor", + csr_col_ptr_tensor: "torch.Tensor", ): """ Add self loop to sampled CSR graph diff --git a/python/pylibwholegraph/pylibwholegraph/torch/graph_structure.py b/python/pylibwholegraph/pylibwholegraph/torch/graph_structure.py index 700b94c9..bb6d75b3 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/graph_structure.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/graph_structure.py @@ -1,12 +1,14 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -import torch +from pylibwholegraph.utils.imports import import_optional from typing import Union, List from .tensor import WholeMemoryTensor from . import graph_ops from . 
import wholegraph_ops +torch = import_optional("torch") + class GraphStructure(object): r"""Graph structure storage @@ -67,7 +69,7 @@ def set_edge_attribute(self, attr_name: str, attr_tensor: WholeMemoryTensor): def unweighted_sample_without_replacement_one_hop( self, - center_nodes_tensor: torch.Tensor, + center_nodes_tensor: "torch.Tensor", max_sample_count: int, *, random_seed: Union[int, None] = None, @@ -98,7 +100,7 @@ def unweighted_sample_without_replacement_one_hop( def weighted_sample_without_replacement_one_hop( self, weight_name: str, - center_nodes_tensor: torch.Tensor, + center_nodes_tensor: "torch.Tensor", max_sample_count: int, *, random_seed: Union[int, None] = None, @@ -133,7 +135,7 @@ def weighted_sample_without_replacement_one_hop( def multilayer_sample_without_replacement( self, - node_ids: torch.Tensor, + node_ids: "torch.Tensor", max_neighbors: List[int], weight_name: Union[str, None] = None, ): diff --git a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py index 3f83ee64..6523779b 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py @@ -1,9 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import os -import torch -import torch.utils.dlpack +from pylibwholegraph.utils.imports import import_optional import pylibwholegraph.binding.wholememory_binding as wmb from .comm import ( set_world_info, @@ -13,6 +12,8 @@ ) from .utils import str_to_wmb_wholememory_log_level +torch = import_optional("torch") + def init( world_rank: int, diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index c9950b3e..73710ec8 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -1,8 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 import pylibwholegraph.binding.wholememory_binding as wmb -import torch +from pylibwholegraph.utils.imports import import_optional from .utils import ( torch_dtype_to_wholememory_dtype, wholememory_dtype_to_torch_dtype, @@ -15,6 +15,7 @@ from .dlpack_utils import torch_import_from_dlpack from .wholegraph_env import wrap_torch_tensor, get_wholegraph_env_fns, get_stream +torch = import_optional("torch") WholeMemoryMemoryType = wmb.WholeMemoryMemoryType WholeMemoryMemoryLocation = wmb.WholeMemoryMemoryLocation @@ -49,7 +50,7 @@ def get_comm(self): ) def gather( - self, indice: torch.Tensor, *, force_dtype: Union[torch.dtype, None] = None + self, indice: "torch.Tensor", *, force_dtype: Union["torch.dtype", None] = None ): assert indice.dim() == 1 embedding_dim = self.shape[1] if self.dim() == 2 else 1 @@ -71,7 +72,7 @@ def gather( ) return output_tensor.view(-1) if self.dim() == 1 else output_tensor - def scatter(self, input_tensor: torch.Tensor, indice: torch.Tensor): + def scatter(self, input_tensor: "torch.Tensor", indice: "torch.Tensor"): assert indice.dim() == 1 assert input_tensor.dim() == self.dim() assert indice.shape[0] == 
input_tensor.shape[0] @@ -201,7 +202,7 @@ def create_wholememory_tensor( memory_type: str, memory_location: str, sizes: List[int], - dtype: torch.dtype, + dtype: "torch.dtype", strides: List[int], tensor_entry_partition: Union[List[int], None] = None, ): @@ -250,7 +251,7 @@ def create_wholememory_tensor_from_filelist( memory_type: str, memory_location: str, filelist: Union[List[str], str], - dtype: torch.dtype, + dtype: "torch.dtype", last_dim_size: int = 0, last_dim_strides: int = -1, tensor_entry_partition: Union[List[int], None] = None, diff --git a/python/pylibwholegraph/pylibwholegraph/torch/utils.py b/python/pylibwholegraph/pylibwholegraph/torch/utils.py index a1b296da..4f27061a 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/utils.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/utils.py @@ -1,15 +1,16 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 import pylibwholegraph.binding.wholememory_binding as wmb -import torch +from pylibwholegraph.utils.imports import import_optional import os +torch = import_optional("torch") WholeMemoryDataType = wmb.WholeMemoryDataType -def torch_dtype_to_wholememory_dtype(torch_dtype: torch.dtype): +def torch_dtype_to_wholememory_dtype(torch_dtype: "torch.dtype"): """ Convert torch.dtype to WholeMemoryDataType :param torch_dtype: torch.dtype diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py index f59418fe..d9c90a5e 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py @@ -1,14 +1,17 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import os.path import importlib -import torch import pylibwholegraph import pylibwholegraph.binding.wholememory_binding as wmb +from pylibwholegraph.utils.imports import import_optional from typing import Union from .utils import wholememory_dtype_to_torch_dtype, torch_dtype_to_wholememory_dtype +torch = import_optional("torch") +torch_utils_cpp_ext = import_optional("torch.utils.cpp_extension") + default_wholegraph_env_context = None torch_cpp_ext_loaded = False torch_cpp_ext_lib = None @@ -46,7 +49,7 @@ def get_c_context(self): else: return id(self) - def set_tensor(self, t: torch.Tensor): + def set_tensor(self, t: "torch.Tensor"): self.tensor = t def get_handle(self): @@ -154,7 +157,7 @@ def get_wholegraph_env_fns(use_default=True) -> int: return wholegraph_env_context.get_env_fns() -def wrap_torch_tensor(t: Union[torch.Tensor, None]) -> wmb.WrappedLocalTensor: +def wrap_torch_tensor(t: Union["torch.Tensor", None]) -> wmb.WrappedLocalTensor: py_desc = wmb.PyWholeMemoryTensorDescription() wm_t = wmb.WrappedLocalTensor() if t is None: @@ -171,8 +174,6 @@ def get_cpp_extension_src_path(): def compile_cpp_extension(): - import torch.utils.cpp_extension - global torch_cpp_ext_loaded global torch_cpp_ext_lib cpp_extension_path = os.path.join(get_cpp_extension_src_path(), "torch_cpp_ext") @@ -192,7 +193,7 @@ def compile_cpp_extension(): extra_ldflags.append( "".join(["-L", os.path.join(os.environ["LIBWHOLEGRAPH_DIR"], "lib")]) ) - torch.utils.cpp_extension.load( + torch_utils_cpp_ext.load( name="pylibwholegraph.pylibwholegraph_torch_ext", sources=[ os.path.join(cpp_extension_path, "wholegraph_torch_ext.cpp"), diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py index c6808010..70b61ac4 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py @@ -1,8 +1,8 @@ 
-# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -import torch import pylibwholegraph.binding.wholememory_binding as wmb +from pylibwholegraph.utils.imports import import_optional from .wholegraph_env import ( get_stream, TorchMemoryContext, @@ -12,11 +12,13 @@ from typing import Union import random +torch = import_optional("torch") + def unweighted_sample_without_replacement( wm_csr_row_ptr_tensor: wmb.PyWholeMemoryTensor, wm_csr_col_ptr_tensor: wmb.PyWholeMemoryTensor, - center_nodes_tensor: torch.Tensor, + center_nodes_tensor: "torch.Tensor", max_sample_count: int, random_seed: Union[int, None] = None, need_center_local_output: bool = False, @@ -85,7 +87,7 @@ def weighted_sample_without_replacement( wm_csr_row_ptr_tensor: wmb.PyWholeMemoryTensor, wm_csr_col_ptr_tensor: wmb.PyWholeMemoryTensor, wm_csr_weight_ptr_tensor: wmb.PyWholeMemoryTensor, - center_nodes_tensor: torch.Tensor, + center_nodes_tensor: "torch.Tensor", max_sample_count: int, random_seed: Union[int, None] = None, need_center_local_output: bool = False, diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py index 9cb518c5..dfcf7041 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py @@ -1,8 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 -import torch import pylibwholegraph.binding.wholememory_binding as wmb +from pylibwholegraph.utils.imports import import_optional from .wholegraph_env import ( get_stream, get_wholegraph_env_fns, @@ -10,10 +10,12 @@ ) from .utils import wholememory_dtype_to_torch_dtype +torch = import_optional("torch") + def wholememory_gather_forward_functor( wholememory_tensor: wmb.PyWholeMemoryTensor, - indices_tensor: torch.Tensor, + indices_tensor: "torch.Tensor", requires_grad=False, torch_output_dtype=None, ): @@ -48,8 +50,8 @@ def wholememory_gather_forward_functor( def wholememory_scatter_functor( - input_tensor: torch.Tensor, - indices_tensor: torch.Tensor, + input_tensor: "torch.Tensor", + indices_tensor: "torch.Tensor", wholememory_tensor: wmb.PyWholeMemoryTensor, ): """ diff --git a/python/pylibwholegraph/pylibwholegraph/utils/imports.py b/python/pylibwholegraph/pylibwholegraph/utils/imports.py new file mode 100644 index 00000000..564a6b97 --- /dev/null +++ b/python/pylibwholegraph/pylibwholegraph/utils/imports.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 + +from importlib import import_module +from importlib.util import find_spec + + +class MissingModule: + """ + Raises RuntimeError when any attribute is accessed on instances of this + class. + + Instances of this class are returned by import_optional() when a module + cannot be found, which allows for code to import optional dependencies, and + have only the code paths that use the module affected. 
+    """
+
+    def __init__(self, mod_name):
+        self.name = mod_name
+
+    def __getattr__(self, attr):
+        raise RuntimeError(f"This feature requires the '{self.name}' package/module")
+
+
+class FoundModule:
+    def __init__(self, mod):
+        self.mod = mod
+        self.imported = False
+
+    def __getattr__(self, attr):
+        if not self.imported:
+            self.mod = import_module(self.mod)
+            self.imported = True
+        return getattr(self.mod, attr)
+
+
+def import_optional(mod, default_mod_class=MissingModule):
+    """
+    import the "optional" module 'mod' and return the module object.
+    If the import raises ModuleNotFoundError, returns an instance of
+    default_mod_class.
+
+    This method was written to support importing "optional" dependencies so
+    code can be written to run even if the dependency is not installed.
+
+    Example
+    -------
+    >> from pylibwholegraph.utils.imports import import_optional
+    >> torch = import_optional("torch")  # torch is not installed
+    >> torch.set_num_threads(1)
+    Traceback (most recent call last):
+      File "<stdin>", line 1, in <module>
+    ...
+    RuntimeError: This feature requires the 'torch' package/module
+    """
+    # this try-except is necessary to handle dotted imports,
+    # like `import_optional("torch.autograd")`
+    mod_found = False
+    try:
+        mod_found = find_spec(mod) is not None
+    except ImportError:
+        mod_found = False
+
+    if mod_found:
+        return FoundModule(mod)
+    else:
+        return default_mod_class(mod_name=mod)