Skip to content

Commit

Permalink
support refactored nvfuser
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Apr 19, 2023
1 parent 324cee0 commit 768406d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
16 changes: 11 additions & 5 deletions csrc/instance_norm_nvfuser_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

#include <torch/extension.h>

// The following header file is found in `PYTORCH_HOME`
#include <aten/src/ATen/native/utils/ParamsHash.h>

#if NVFUSER_THIRDPARTY
#include <kernel_cache.h>
#include <ops/all_ops.h>
#else
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>

#include <aten/src/ATen/native/utils/ParamsHash.h>
#endif

using namespace torch::jit::fuser::cuda;
using namespace at::indexing;
Expand Down Expand Up @@ -85,7 +91,7 @@ std::vector<at::Tensor> instance_norm_nvfuser_forward(
}
InstanceNormKey forward_key;
setKey(input, weight, run_mean, channels_last, forward_key);
if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) {
if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

Expand Down Expand Up @@ -130,7 +136,7 @@ std::vector<at::Tensor> instance_norm_nvfuser_forward(
if (!run_mean.sizes().size()) {
_running_mean = nullptr;
_running_var = nullptr;
}
}
if (!weight.sizes().size()) {
_weight = nullptr;
_bias = nullptr;
Expand Down Expand Up @@ -235,7 +241,7 @@ std::vector<at::Tensor> instance_norm_nvfuser_backward(
if (!run_mean.sizes().size()) {
_running_mean = nullptr;
_running_var = nullptr;
}
}
if (!weight.sizes().size()) {
_weight = nullptr;
}
Expand Down
24 changes: 19 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,13 +361,27 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
)

if PYTORCH_HOME is not None and os.path.exists(PYTORCH_HOME):
nvfuser_is_refactored = "nvfuser" in (
os.path.join(d) for d in os.listdir(os.path.join(PYTORCH_HOME, "third_party"))
if os.path.isdir(os.path.join(os.path.join(PYTORCH_HOME, "third_party"), d))
)
print(PYTORCH_HOME)
include_dirs = [PYTORCH_HOME]
if nvfuser_is_refactored:
include_dirs.append(os.path.join(PYTORCH_HOME, "third_party/nvfuser/csrc"))
ext_modules.append(
CUDAExtension('instance_norm_nvfuser_cuda',
['csrc/instance_norm_nvfuser.cpp', 'csrc/instance_norm_nvfuser_kernel.cu'],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros,
"nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + [f"-I {PYTORCH_HOME}"])},
)
CUDAExtension(
name='instance_norm_nvfuser_cuda',
sources=[
'csrc/instance_norm_nvfuser.cpp',
'csrc/instance_norm_nvfuser_kernel.cu',
],
include_dirs=include_dirs,
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros,
"nvcc": ["-O3"] + version_dependent_macros + [f"-DNVFUSER_THIRDPARTY={int(nvfuser_is_refactored)}"],
},
)
)

if "--permutation_search" in sys.argv:
Expand Down

0 comments on commit 768406d

Please sign in to comment.