Skip to content

Commit

Permalink
Merge pull request #109 from pytorch/_C-instead-of-load_library
Browse files Browse the repository at this point in the history
Create dummymodule _C instead of using load_library(blah.so)
  • Loading branch information
janeyx99 authored Jan 24, 2025
2 parents 4be2205 + 58ac996 commit 38ec45e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
9 changes: 1 addition & 8 deletions extension_cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
import torch
from pathlib import Path

so_files = list(Path(__file__).parent.glob("_C*.so"))
assert (
len(so_files) == 1
), f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])

from . import ops
from . import _C, ops
24 changes: 23 additions & 1 deletion extension_cpp/csrc/muladd.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,29 @@
#include <torch/extension.h>
#include <Python.h>
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

#include <vector>

extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
The import from Python will load the .so consisting of this file
in this extension, so that the TORCH_LIBRARY static initializers
below are run. */
PyObject* PyInit__C(void)
{
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_C", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
NULL, /* methods */
};
return PyModule_Create(&module_def);
}
}

namespace extension_cpp {

at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_extensions():
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-DPy_LIMITED_API=0x03090000", # min CPython version 3.9
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
Expand Down

0 comments on commit 38ec45e

Please sign in to comment.