Skip to content

Commit

Permalink
Merge pull request #105 from pytorch/python-agnostic
Browse files Browse the repository at this point in the history
remove pybind usage, make example python agnostic
  • Loading branch information
janeyx99 authored Jan 22, 2025
2 parents 8fe0de2 + e4c4eb8 commit 4be2205
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
10 changes: 9 additions & 1 deletion extension_cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@
import torch
from . import _C, ops
from pathlib import Path

so_files = list(Path(__file__).parent.glob("_C*.so"))
assert (
len(so_files) == 1
), f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])

from . import ops
3 changes: 0 additions & 3 deletions extension_cpp/csrc/muladd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
}
}

// Registers _C as a Python extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

// Defines the operators
TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
Expand Down
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

library_name = "extension_cpp"

if torch.__version__ >= "2.6.0":
py_limited_api = True
else:
py_limited_api = False


def get_extensions():
debug_mode = os.getenv("DEBUG", "0") == "1"
Expand Down Expand Up @@ -59,6 +64,7 @@ def get_extensions():
sources,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
py_limited_api=py_limited_api,
)
]

Expand All @@ -71,9 +77,10 @@ def get_extensions():
packages=find_packages(),
ext_modules=get_extensions(),
install_requires=["torch"],
description="Example of PyTorch cpp and CUDA extensions",
description="Example of PyTorch C++ and CUDA extensions",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch/extension-cpp",
cmdclass={"build_ext": BuildExtension},
options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
)

0 comments on commit 4be2205

Please sign in to comment.