From 5edf81b2f447522dff0940863d4559fb9abbddc8 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Fri, 25 Feb 2022 23:37:44 +0000 Subject: [PATCH 01/15] initial check in --- apex/normalization/__init__.py | 1 + apex/normalization/instance_norm.py | 149 ++++++++++ csrc/instance_norm_nvfuser.cpp | 35 +++ csrc/instance_norm_nvfuser_kernel.cu | 269 ++++++++++++++++++ setup.py | 12 + .../L0/run_instance_norm_nvfuser/__init__.py | 0 .../test_instance_norm_nvfuser.py | 59 ++++ 7 files changed, 525 insertions(+) create mode 100644 apex/normalization/instance_norm.py create mode 100644 csrc/instance_norm_nvfuser.cpp create mode 100644 csrc/instance_norm_nvfuser_kernel.cu create mode 100644 tests/L0/run_instance_norm_nvfuser/__init__.py create mode 100644 tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py diff --git a/apex/normalization/__init__.py b/apex/normalization/__init__.py index c649913fd..8bd4f159f 100644 --- a/apex/normalization/__init__.py +++ b/apex/normalization/__init__.py @@ -1 +1,2 @@ from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm +from .instance_norm import InstanceNorm3dNVFuser diff --git a/apex/normalization/instance_norm.py b/apex/normalization/instance_norm.py new file mode 100644 index 000000000..48870b33e --- /dev/null +++ b/apex/normalization/instance_norm.py @@ -0,0 +1,149 @@ +import importlib + +import torch +from torch import Tensor +from torch.nn.modules.batchnorm import _NormBase + +global instance_norm_nvfuser_cuda +instance_norm_nvfuser_cuda = None + +class InstanceNormNVFuserFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias, running_mean, running_var, + use_input_stats, momentum, eps): + global instance_norm_nvfuser_cuda + if instance_norm_nvfuser_cuda is None: + instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda") + + channels_last = input.is_contiguous(memory_format=torch.channels_last) or input.is_contiguous(memory_format=torch.channels_last_3d) + if channels_last: + order = [0] + [i for i in range(2, len(input.shape))] + [1] + _input = input.permute(order) + else: + _input = input + assert _input.is_contiguous() + result = instance_norm_nvfuser_cuda.forward(_input, weight, bias, running_mean, running_var, + use_input_stats, momentum, eps, channels_last) + if len(result) == 3: + out, mean, invstd = result + else: + running_mean, running_var, out, mean, invstd = result + ctx.use_input_stats = use_input_stats + ctx.eps = eps + ctx.channels_last = channels_last + # saving for backward in "explicit channels-last format" + ctx.save_for_backward(_input, weight, running_mean, running_var, mean, invstd) + if channels_last: + order = [0, len(_input.shape) - 1] + [i for i in range(1, len(_input.shape) - 1)] + out = out.permute(order) + if len(out.shape) == 4: + assert out.is_contiguous(memory_format=torch.channels_last) + assert input.is_contiguous(memory_format=torch.channels_last) + elif len(out.shape) == 5: + assert out.is_contiguous(memory_format=torch.channels_last_3d) + assert input.is_contiguous(memory_format=torch.channels_last_3d) + else: + assert False, "unhandled channels_last format variation in forward" + return out + + @staticmethod + def backward(ctx, grad_output): + global instance_norm_nvfuser_cuda + if instance_norm_nvfuser_cuda is None: + instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda") + + if ctx.channels_last: + order = [0] + [i for i in range(2, len(grad_output.shape))] + [1] + grad_output = 
grad_output.permute(order) + # input was saved in "explicit channels-last format" + assert ctx.saved_tensors[0].is_contiguous() + grad_output = grad_output.contiguous() + saved = list(ctx.saved_tensors) + saved.insert(1, grad_output) + running_mean = saved[3] + running_var = saved[4] + mean = saved[-2] + var = saved[-1] + grad_input, grad_weight, grad_bias = instance_norm_nvfuser_cuda.backward(*saved, ctx.use_input_stats, ctx.eps, ctx.channels_last) + if ctx.channels_last: + order = [0, len(grad_input.shape) - 1] + [i for i in range(1, len(grad_input.shape) - 1)] + grad_input = grad_input.permute(order) + if len(grad_input.shape) == 4: + assert grad_input.is_contiguous(memory_format=torch.channels_last) + elif len(grad_input.shape) == 5: + assert grad_input.is_contiguous(memory_format=torch.channels_last_3d) + else: + assert False, "unhandled channels_last format variation in backward" + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None + + +class _InstanceNormNVFuser(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super(_InstanceNormNVFuser, self).__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs) + self.dummy = torch.empty([], device='cuda') + + def _check_input_dim(self, input): + raise NotImplementedError + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + # at version 1: removed running_mean and running_var when + # track_running_stats=False (default) + if version is None and not self.track_running_stats: + running_stats_keys = [] + for name in ('running_mean', 'running_var'): + key = prefix + name + if key in state_dict: + running_stats_keys.append(key) + if len(running_stats_keys) > 0: + error_msgs.append( + 'Unexpected running stats buffer(s) {names} for {klass} ' + 'with track_running_stats=False. If state_dict is a ' + 'checkpoint saved before 0.4.0, this may be expected ' + 'because {klass} does not track running stats by default ' + 'since 0.4.0. Please remove these keys from state_dict. If ' + 'the running stats are actually needed, instead set ' + 'track_running_stats=True in {klass} to enable them. See ' + 'the documentation of {klass} for details.' 
+ .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys), + klass=self.__class__.__name__)) + for key in running_stats_keys: + state_dict.pop(key) + + super(_InstanceNormNVFuser, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, input: Tensor) -> Tensor: + assert input.is_cuda, "NVFuser InstanceNorm is CUDA only" + self._check_input_dim(input) + if self.running_mean is not None: + out = InstanceNormNVFuserFunction.apply( + input, self.weight if self.weight is not None else self.dummy, + self.bias if self.bias is not None else self.dummy, self.running_mean, self.running_var, + self.training or not self.track_running_stats, self.momentum, self.eps) + else: + out = InstanceNormNVFuserFunction.apply( + input, self.weight if self.weight is not None else self.dummy, + self.bias if self.bias is not None else self.dummy, self.dummy, self.dummy, + self.training or not self.track_running_stats, self.momentum, self.eps) + return out + +class InstanceNorm3dNVFuser(_InstanceNormNVFuser): + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + diff --git a/csrc/instance_norm_nvfuser.cpp b/csrc/instance_norm_nvfuser.cpp new file mode 100644 index 000000000..3af0f28be --- /dev/null +++ b/csrc/instance_norm_nvfuser.cpp @@ -0,0 +1,35 @@ +#include +#include + +#include + +std::vector instance_norm_nvfuser_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor run_mean, + at::Tensor run_var, + const bool use_input_stats, + const float momentum, + const float eps, + const bool channels_last = false + ); + +std::vector instance_norm_nvfuser_backward( + at::Tensor input, + at::Tensor grad_output, + at::Tensor weight, + at::Tensor running_mean, + at::Tensor running_var, + at::Tensor save_mean, + at::Tensor save_invstd, + const bool use_input_stats, + const float eps, + // const std::vector& output_mask, + bool channels_last = false + ); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &instance_norm_nvfuser_forward, "instance_norm forward (CUDA)"); + m.def("backward", &instance_norm_nvfuser_backward, "instance_norm backward (CUDA)"); +} diff --git a/csrc/instance_norm_nvfuser_kernel.cu b/csrc/instance_norm_nvfuser_kernel.cu new file mode 100644 index 000000000..3c033c544 --- /dev/null +++ b/csrc/instance_norm_nvfuser_kernel.cu @@ -0,0 +1,269 @@ +#include +#include +#include + +#include + +#include +#include + +// Hashing machinery for Params +// Fowler–Noll–Vo hash function +// see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function +template +struct ParamsHash { + // Params must be a POD because we read out its memory + // contenst as char* when hashing + static_assert(std::is_pod::value, "Params is not POD"); + + size_t operator()(const Params& params) const { + auto ptr = reinterpret_cast(¶ms); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < (int)sizeof(Params); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +template +struct ParamsEqual { + // Params must be a POD because we read out its memory + // contenst as char* when comparing + static_assert(std::is_pod::value, "Params is not POD"); + + bool operator()(const Params& a, const Params& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(Params)) == 0; + } +}; + +using namespace 
torch::jit::fuser::cuda; +using namespace at::indexing; + +// Make a tensor that is known to be fully contiguous of dimensionality=ndims, +// but unknown sizes +TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { + return TensorViewBuilder() + .ndims(ndims) + .dtype(dtype) + .contiguity(std::vector(ndims, true)) + .build(); +} + +struct InstanceNormKey { + c10::ScalarType input_dtype;//int8_t dtype; + c10::ScalarType weight_dtype; + c10::ScalarType mean_dtype; + size_t dim; + bool channels_last; + bool running_mean; + bool affine; + float eps; +}; + +auto get_dtype(c10::ScalarType dtype) { + auto ret_dtype = DataType::Float; + if (dtype == c10::ScalarType::Double) { + ret_dtype = DataType::Double; + } else if (dtype == c10::ScalarType::Half) { + ret_dtype = DataType::Half; + } else if (dtype == c10::ScalarType::BFloat16) { + ret_dtype = DataType::BFloat16; + } + return ret_dtype; +} + +// TODO: doesn't support all combinations of dtype e.g., bias, run_var, .. +// bias is assumed to match weight, run_var is assumed to match run_mean +void getKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, const float eps, InstanceNormKey& key) { + memset(&key, 0, sizeof(InstanceNormKey)); + key.input_dtype = input.scalar_type();// static_cast(input.scalar_type()); + key.weight_dtype = weight.scalar_type(); + key.mean_dtype = run_mean.scalar_type(); + key.dim = input.sizes().size(); + key.channels_last = channels_last; + key.eps = eps; + key.running_mean = run_mean.sizes().size() > 0; + key.affine = weight.sizes().size() ? true : false; +} + +std::unordered_map, ParamsHash, ParamsEqual > forward_fusion_cache; +std::unordered_map, ParamsHash, ParamsEqual > backward_fusion_cache; + +std::vector instance_norm_nvfuser_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor run_mean, + at::Tensor run_var, + const bool use_input_stats, + const float momentum, + const float eps, + const bool channels_last) { + InstanceNormKey forward_key; + memset(&forward_key, 0, sizeof(InstanceNormKey)); + getKey(input, weight, run_mean, channels_last, eps, forward_key); + if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto _input_dtype = get_dtype(input.scalar_type()); + const auto _weight_dtype = get_dtype(weight.scalar_type()); + const auto _bias_dtype = get_dtype(bias.scalar_type()); + const auto _running_mean_dtype = get_dtype(run_mean.scalar_type()); + const auto _running_var_dtype = get_dtype(run_var.scalar_type()); + auto _input = makeContigTensor(input.sizes().size(), _input_dtype); + auto _weight = makeContigTensor(weight.sizes().size(), _weight_dtype); + auto _bias = makeContigTensor(bias.sizes().size(), _bias_dtype); + auto _running_mean = makeContigTensor(run_mean.sizes().size(), get_dtype(run_mean.scalar_type())); + auto _running_var = makeContigTensor(run_var.sizes().size(), get_dtype(run_var.scalar_type())); + + fusion->addInput(_input); + fusion->addInput(_weight); + fusion->addInput(_bias); + + if (_input_dtype == DataType::Half || _input_dtype == DataType::BFloat16) { + _input = castOp(DataType::Float, _input); + } + if (_weight_dtype == DataType::Half || _weight_dtype == DataType::BFloat16) { + _weight = castOp(DataType::Float, _weight); + } + if (_bias_dtype == DataType::Half || _bias_dtype == DataType::BFloat16) { + _bias = castOp(DataType::Float, _bias); + } + + // TODO: decide if passing 
an empty tensor is the best way to signal no running mean/var + if (run_mean.sizes().size()) { + fusion->addInput(_running_mean); + fusion->addInput(_running_var); + // casting is done by Forward for running mean/var as it needs original inputs for aliasing + } + + Double* _momentum = IrBuilder::create(momentum); + Double* _eps = IrBuilder::create(eps); + + ForwardNormResult result; + if (!run_mean.sizes().size()) { + _running_mean = nullptr; + _running_var = nullptr; + } + if (!weight.sizes().size()) { + _weight = nullptr; + _bias = nullptr; + } + result = instance_norm( + _input, _weight, _bias, _running_mean, _running_var, use_input_stats, _momentum, _eps, channels_last); + + if (_input_dtype == DataType::Half || _input_dtype == DataType::BFloat16) { + fusion->addOutput(castOp(_input_dtype, result.output)); + fusion->addOutput(castOp(_input_dtype, result.mean)); + fusion->addOutput(castOp(_input_dtype, result.invstd)); + } else { + fusion->addOutput(result.output); + fusion->addOutput(result.mean); + fusion->addOutput(result.invstd); + } + forward_fusion_cache.emplace(forward_key, std::make_unique(std::move(fusion))); // need std::move right + } + std::vector aten_inputs = {input, weight, bias}; + if (run_mean.sizes().size()) { + aten_inputs.push_back(run_mean); + aten_inputs.push_back(run_var); + } + return forward_fusion_cache[forward_key].get()->runFusionWithInputs(aten_inputs); +} + +std::vector instance_norm_nvfuser_backward( + at::Tensor input, + at::Tensor grad_output, + at::Tensor weight, + at::Tensor run_mean, + at::Tensor run_var, + at::Tensor save_mean, + at::Tensor save_invstd, + const bool use_input_stats, + const float eps, + // const std::vector& output_mask, + bool channels_last + ) { + InstanceNormKey backward_key; + memset(&backward_key, 0, sizeof(InstanceNormKey)); + getKey(input, weight, run_mean, channels_last, eps, backward_key); + if (backward_fusion_cache.find(backward_key) == backward_fusion_cache.end()) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const auto _input_dtype = get_dtype(input.scalar_type()); + const auto _grad_output_dtype = get_dtype(grad_output.scalar_type()); + const auto _weight_dtype = get_dtype(weight.scalar_type()); + const auto _running_mean_dtype = get_dtype(run_mean.scalar_type()); + const auto _running_var_dtype = get_dtype(run_var.scalar_type()); + auto _input = makeContigTensor(input.sizes().size(), _input_dtype); + auto _grad_output = makeContigTensor(grad_output.sizes().size(), _grad_output_dtype); + auto _weight = makeContigTensor(weight.sizes().size(), _weight_dtype); + auto _running_mean = makeContigTensor(run_mean.sizes().size(), get_dtype(run_mean.scalar_type())); + auto _running_var = makeContigTensor(run_var.sizes().size(), get_dtype(run_var.scalar_type())); + auto _save_mean = makeContigTensor(save_mean.sizes().size(), get_dtype(save_mean.scalar_type())); + auto _save_invstd = makeContigTensor(save_invstd.sizes().size(), get_dtype(save_invstd.scalar_type())); + + fusion->addInput(_input); + fusion->addInput(_grad_output); + fusion->addInput(_weight); + fusion->addInput(_running_mean); + fusion->addInput(_running_var); + fusion->addInput(_save_mean); + fusion->addInput(_save_invstd); + + if (_input_dtype == DataType::Half || _input_dtype == DataType::BFloat16) { + _input = castOp(DataType::Float, _input); + } + if (_grad_output_dtype == DataType::Half || _grad_output_dtype == DataType::BFloat16) { + _grad_output = castOp(DataType::Float, _grad_output); + } + if (_weight_dtype == DataType::Half || 
_weight_dtype == DataType::BFloat16) { + _weight = castOp(DataType::Float, _weight); + } + if (_running_mean_dtype == DataType::Half || _running_mean_dtype == DataType::BFloat16) { + _running_mean = castOp(DataType::Float, _running_mean); + } + if (_running_var_dtype == DataType::Half || _running_var_dtype == DataType::BFloat16) { + _running_var = castOp(DataType::Float, _running_var); + } + + + Double* _eps = IrBuilder::create(eps); + if (!run_mean.sizes().size()) { + _running_mean = nullptr; + _running_var = nullptr; + } + if (!weight.sizes().size()) { + _weight = nullptr; + } + auto result = instance_norm_backward(_input, + _grad_output, + _weight, + _running_mean, + _running_var, + _save_mean, + _save_invstd, + use_input_stats, + _eps, + {true, true, true}, // TODO: is output mask useful? + channels_last); + if (_input_dtype == DataType::Half || _input_dtype == DataType::BFloat16) { + fusion->addOutput(castOp(_input_dtype, result.grad_input)); + fusion->addOutput(castOp(_input_dtype, result.grad_weight)); + fusion->addOutput(castOp(_input_dtype, result.grad_bias)); + } else { + fusion->addOutput(result.grad_input); + fusion->addOutput(result.grad_weight); + fusion->addOutput(result.grad_bias); + } + backward_fusion_cache.emplace(backward_key, std::make_unique(std::move(fusion))); + } + std::vector aten_inputs = { + input, grad_output, weight, run_mean, run_var, save_mean, save_invstd}; + return backward_fusion_cache[backward_key].get()->runFusionWithInputs(aten_inputs); + } diff --git a/setup.py b/setup.py index cb1a79067..8f31779aa 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,8 @@ import torch from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, load +PYTORCH_HOME = os.path.abspath(os.environ['PYTORCH_HOME']) if 'PYTORCH_HOME' in os.environ else None + # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -358,6 +360,16 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ) +if PYTORCH_HOME is not None and os.path.exists(PYTORCH_HOME): + print(PYTORCH_HOME) + ext_modules.append( + CUDAExtension('instance_norm_nvfuser_cuda', + ['csrc/instance_norm_nvfuser.cpp', 'csrc/instance_norm_nvfuser_kernel.cu'], + extra_compile_args={"cxx": ["-O3"] + version_dependent_macros, + "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + [f"-I {PYTORCH_HOME}"])}, + ) + ) + if "--permutation_search" in sys.argv: sys.argv.remove("--permutation_search") diff --git a/tests/L0/run_instance_norm_nvfuser/__init__.py b/tests/L0/run_instance_norm_nvfuser/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py new file mode 100644 index 000000000..797c9a8f6 --- /dev/null +++ b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py @@ -0,0 +1,59 @@ +import itertools +import unittest + +import torch + +import apex +from apex.normalization import InstanceNorm3dNVFuser + +class TestInstanceNormNVFuser(unittest.TestCase): + dtype = torch.float + track_running_stats = False + channels_last = False + affine = False + batch_size = 5 + channel_size = 7 + spatial_size = 3 + + def setUp(self): + self.m = InstanceNorm3dNVFuser(self.channel_size, affine=self.affine, track_running_stats=self.track_running_stats, device='cuda', dtype=self.dtype) + self.reference_m = torch.nn.InstanceNorm3d(self.channel_size, 
affine=self.affine, track_running_stats=self.track_running_stats, device='cuda', dtype=self.dtype) + + def check_same_output(self): + torch.manual_seed(42) + for i in range(2): # exercise JIT + caching + inp = torch.randint(0, 2, (self.batch_size, self.channel_size, self.spatial_size, self.spatial_size, self.spatial_size), device='cuda', requires_grad=True, dtype=self.dtype) + inp2 = inp.detach().clone() + inp2.requires_grad = True + if self.channels_last: + _inp = inp.to(memory_format=torch.channels_last_3d) + else: + _inp = inp + out = self.m(_inp) + (out.sum()).backward() + out2 = self.reference_m(inp2) + if self.m.running_mean is None: + assert self.reference_m.running_mean is None + assert self.m.running_var is None + assert self.reference_m.running_var is None + else: + torch.testing.assert_close(self.m.running_mean, self.reference_m.running_mean) + if self.dtype == torch.float16: + torch.testing.assert_close(self.m.running_var, self.reference_m.running_var, atol=5e-3, rtol=5e-3) + else: + torch.testing.assert_close(self.m.running_var, self.reference_m.running_var) + torch.testing.assert_close(out, out2) + (out2.sum()).backward() + if self.dtype == torch.float16: + torch.testing.assert_close(inp.grad, inp2.grad, atol=5e-3, rtol=5e-3) + else: + torch.testing.assert_close(inp.grad, inp2.grad) + + def test_sweep(self): + for dtype, track_running_stats, channels_last, affine in itertools.product((torch.float, torch.half), (False, True), (False, True), (False, True)): + self.dtype = dtype + self.track_running_stats = track_running_stats + self.channels_last = channels_last + self.affine = affine + self.setUp() + self.check_same_output() From 791d8151aa098be5fe0d894b9fe5e551736d9848 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 1 Mar 2022 19:28:58 +0000 Subject: [PATCH 02/15] add weight and bias check --- .../test_instance_norm_nvfuser.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py index 797c9a8f6..48b0ad89c 100644 --- a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py +++ b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py @@ -48,6 +48,16 @@ def check_same_output(self): torch.testing.assert_close(inp.grad, inp2.grad, atol=5e-3, rtol=5e-3) else: torch.testing.assert_close(inp.grad, inp2.grad) + if self.m.weight is not None: + if self.dtype == torch.float16: + torch.testing.assert_close(self.m.weight.grad, self.reference_m.weight.grad, atol=5e-2, rtol=5e-2) + else: + torch.testing.assert_close(self.m.weight.grad, self.reference_m.weight.grad) + if self.m.bias is not None: + if self.dtype == torch.float16: + torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad, atol=5e-3, rtol=5e-3) + else: + torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad) def test_sweep(self): for dtype, track_running_stats, channels_last, affine in itertools.product((torch.float, torch.half), (False, True), (False, True), (False, True)): From a766a1b41e3903a9150c02e23592fcf14cbeed7c Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Thu, 10 Mar 2022 01:30:08 +0000 Subject: [PATCH 03/15] address comments, cleanup --- csrc/instance_norm_nvfuser_kernel.cu | 44 +++---------------- .../test_instance_norm_nvfuser.py | 9 ++-- 2 files changed, 12 insertions(+), 41 deletions(-) diff --git a/csrc/instance_norm_nvfuser_kernel.cu b/csrc/instance_norm_nvfuser_kernel.cu index 3c033c544..6ef005f75 100644 --- 
a/csrc/instance_norm_nvfuser_kernel.cu +++ b/csrc/instance_norm_nvfuser_kernel.cu @@ -7,38 +7,7 @@ #include #include -// Hashing machinery for Params -// Fowler–Noll–Vo hash function -// see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function -template -struct ParamsHash { - // Params must be a POD because we read out its memory - // contenst as char* when hashing - static_assert(std::is_pod::value, "Params is not POD"); - - size_t operator()(const Params& params) const { - auto ptr = reinterpret_cast(¶ms); - uint32_t value = 0x811C9DC5; - for (int i = 0; i < (int)sizeof(Params); ++i) { - value ^= ptr[i]; - value *= 0x01000193; - } - return (size_t)value; - } -}; - -template -struct ParamsEqual { - // Params must be a POD because we read out its memory - // contenst as char* when comparing - static_assert(std::is_pod::value, "Params is not POD"); - - bool operator()(const Params& a, const Params& b) const { - auto ptr1 = reinterpret_cast(&a); - auto ptr2 = reinterpret_cast(&b); - return memcmp(ptr1, ptr2, sizeof(Params)) == 0; - } -}; +#include using namespace torch::jit::fuser::cuda; using namespace at::indexing; @@ -78,7 +47,7 @@ auto get_dtype(c10::ScalarType dtype) { // TODO: doesn't support all combinations of dtype e.g., bias, run_var, .. // bias is assumed to match weight, run_var is assumed to match run_mean -void getKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, const float eps, InstanceNormKey& key) { +void setKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, const float eps, InstanceNormKey& key) { memset(&key, 0, sizeof(InstanceNormKey)); key.input_dtype = input.scalar_type();// static_cast(input.scalar_type()); key.weight_dtype = weight.scalar_type(); @@ -90,8 +59,8 @@ void getKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& key.affine = weight.sizes().size() ? 
true : false; } -std::unordered_map, ParamsHash, ParamsEqual > forward_fusion_cache; -std::unordered_map, ParamsHash, ParamsEqual > backward_fusion_cache; +std::unordered_map, at::native::ParamsHash, at::native::ParamsEqual > forward_fusion_cache; +std::unordered_map, at::native::ParamsHash, at::native::ParamsEqual > backward_fusion_cache; std::vector instance_norm_nvfuser_forward( at::Tensor input, @@ -104,8 +73,7 @@ std::vector instance_norm_nvfuser_forward( const float eps, const bool channels_last) { InstanceNormKey forward_key; - memset(&forward_key, 0, sizeof(InstanceNormKey)); - getKey(input, weight, run_mean, channels_last, eps, forward_key); + setKey(input, weight, run_mean, channels_last, eps, forward_key); if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -166,7 +134,7 @@ std::vector instance_norm_nvfuser_forward( fusion->addOutput(result.mean); fusion->addOutput(result.invstd); } - forward_fusion_cache.emplace(forward_key, std::make_unique(std::move(fusion))); // need std::move right + forward_fusion_cache.emplace(forward_key, std::make_unique(std::move(fusion))); } std::vector aten_inputs = {input, weight, bias}; if (run_mean.sizes().size()) { diff --git a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py index 48b0ad89c..8b2866ae6 100644 --- a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py +++ b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py @@ -15,7 +15,7 @@ class TestInstanceNormNVFuser(unittest.TestCase): channel_size = 7 spatial_size = 3 - def setUp(self): + def init_modules(self): self.m = InstanceNorm3dNVFuser(self.channel_size, affine=self.affine, track_running_stats=self.track_running_stats, device='cuda', dtype=self.dtype) self.reference_m = torch.nn.InstanceNorm3d(self.channel_size, affine=self.affine, track_running_stats=self.track_running_stats, device='cuda', dtype=self.dtype) @@ -60,10 +60,13 @@ def check_same_output(self): torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad) def test_sweep(self): - for dtype, track_running_stats, channels_last, affine in itertools.product((torch.float, torch.half), (False, True), (False, True), (False, True)): + dtypes = [torch.float, torch.half] + if torch.cuda.get_device_capability() >= (8, 0): + dtypes.append(torch.bfloat16) + for dtype, track_running_stats, channels_last, affine in itertools.product(dtypes, (False, True), (False, True), (False, True)): self.dtype = dtype self.track_running_stats = track_running_stats self.channels_last = channels_last self.affine = affine - self.setUp() + self.init_modules() self.check_same_output() From b893d03e2d46f12f5e4b73186a58bb115ea3dd83 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Thu, 10 Mar 2022 01:50:38 +0000 Subject: [PATCH 04/15] fix --- csrc/instance_norm_nvfuser_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/instance_norm_nvfuser_kernel.cu b/csrc/instance_norm_nvfuser_kernel.cu index 6ef005f75..a88ad5182 100644 --- a/csrc/instance_norm_nvfuser_kernel.cu +++ b/csrc/instance_norm_nvfuser_kernel.cu @@ -159,7 +159,7 @@ std::vector instance_norm_nvfuser_backward( ) { InstanceNormKey backward_key; memset(&backward_key, 0, sizeof(InstanceNormKey)); - getKey(input, weight, run_mean, channels_last, eps, backward_key); + setKey(input, weight, run_mean, channels_last, eps, backward_key); if 
(backward_fusion_cache.find(backward_key) == backward_fusion_cache.end()) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); From ccd652d2d929fc6cfc7e7b60e133464faa17f1ea Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Thu, 10 Mar 2022 02:57:13 +0000 Subject: [PATCH 05/15] sketchy test numerics twiddling --- .../test_instance_norm_nvfuser.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py index 8b2866ae6..c976c767f 100644 --- a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py +++ b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py @@ -30,7 +30,6 @@ def check_same_output(self): else: _inp = inp out = self.m(_inp) - (out.sum()).backward() out2 = self.reference_m(inp2) if self.m.running_mean is None: assert self.reference_m.running_mean is None @@ -43,18 +42,24 @@ def check_same_output(self): else: torch.testing.assert_close(self.m.running_var, self.reference_m.running_var) torch.testing.assert_close(out, out2) - (out2.sum()).backward() + grad_out = torch.randn_like(inp) + out.backward(grad_out) + out2.backward(grad_out) if self.dtype == torch.float16: torch.testing.assert_close(inp.grad, inp2.grad, atol=5e-3, rtol=5e-3) + elif self.dtype == torch.bfloat16: + torch.testing.assert_close(inp.grad, inp2.grad, atol=2e-2, rtol=2e-2) else: torch.testing.assert_close(inp.grad, inp2.grad) if self.m.weight is not None: if self.dtype == torch.float16: torch.testing.assert_close(self.m.weight.grad, self.reference_m.weight.grad, atol=5e-2, rtol=5e-2) + elif self.dtype == torch.bfloat16: + torch.testing.assert_close(self.m.weight.grad, self.reference_m.weight.grad, atol=7e-2, rtol=8e-2) else: torch.testing.assert_close(self.m.weight.grad, self.reference_m.weight.grad) if self.m.bias is not None: - if self.dtype == torch.float16: + if self.dtype in (torch.float16, torch.bfloat16): torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad, atol=5e-3, rtol=5e-3) else: torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad) From 1fb051244ad3b3149a3bf0858a4eaf3c84b1ebf0 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 15 Mar 2022 23:48:19 +0000 Subject: [PATCH 06/15] add profile, remove scalars from cache key --- csrc/instance_norm_nvfuser_kernel.cu | 66 +++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/csrc/instance_norm_nvfuser_kernel.cu b/csrc/instance_norm_nvfuser_kernel.cu index a88ad5182..673af68f7 100644 --- a/csrc/instance_norm_nvfuser_kernel.cu +++ b/csrc/instance_norm_nvfuser_kernel.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include @@ -12,6 +13,15 @@ using namespace torch::jit::fuser::cuda; using namespace at::indexing; +std::chrono::time_point t1; +std::chrono::time_point t2; +std::chrono::time_point t3; + +bool profile() { + static bool should_profile = std::getenv("APEX_NVFUSER_PROFILE") != nullptr; + return should_profile; +} + // Make a tensor that is known to be fully contiguous of dimensionality=ndims, // but unknown sizes TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { @@ -30,7 +40,6 @@ struct InstanceNormKey { bool channels_last; bool running_mean; bool affine; - float eps; }; auto get_dtype(c10::ScalarType dtype) { @@ -47,14 +56,13 @@ auto get_dtype(c10::ScalarType dtype) { // TODO: doesn't support all combinations of dtype e.g., bias, run_var, .. 
// bias is assumed to match weight, run_var is assumed to match run_mean -void setKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, const float eps, InstanceNormKey& key) { +void setKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, InstanceNormKey& key) { memset(&key, 0, sizeof(InstanceNormKey)); key.input_dtype = input.scalar_type();// static_cast(input.scalar_type()); key.weight_dtype = weight.scalar_type(); key.mean_dtype = run_mean.scalar_type(); key.dim = input.sizes().size(); key.channels_last = channels_last; - key.eps = eps; key.running_mean = run_mean.sizes().size() > 0; key.affine = weight.sizes().size() ? true : false; } @@ -72,8 +80,11 @@ std::vector instance_norm_nvfuser_forward( const float momentum, const float eps, const bool channels_last) { + if (profile()) { + t1 = std::chrono::steady_clock::now(); + } InstanceNormKey forward_key; - setKey(input, weight, run_mean, channels_last, eps, forward_key); + setKey(input, weight, run_mean, channels_last, forward_key); if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -110,8 +121,10 @@ std::vector instance_norm_nvfuser_forward( // casting is done by Forward for running mean/var as it needs original inputs for aliasing } - Double* _momentum = IrBuilder::create(momentum); - Double* _eps = IrBuilder::create(eps); + Double* _momentum = IrBuilder::create(); + Double* _eps = IrBuilder::create(); + fusion->addInput(_momentum); + fusion->addInput(_eps); ForwardNormResult result; if (!run_mean.sizes().size()) { @@ -141,7 +154,21 @@ std::vector instance_norm_nvfuser_forward( aten_inputs.push_back(run_mean); aten_inputs.push_back(run_var); } - return forward_fusion_cache[forward_key].get()->runFusionWithInputs(aten_inputs); + aten_inputs.push_back(momentum); + aten_inputs.push_back(eps); + if (profile()) { + t2 = std::chrono::steady_clock::now(); + } + auto r = forward_fusion_cache[forward_key].get()->runFusionWithInputs(aten_inputs); + if (profile()) { + t3 = std::chrono::steady_clock::now(); + std::chrono::duration full = t3 - t1; + std::chrono::duration pre = t2 - t1; + std::chrono::duration exec = t3 - t2; + std::cout << "NVFuserInstanceNorm Forward (full, pre-exec, exec) (" << full.count() + << ", " << pre.count() << ", " << exec.count() << ")" << std::endl; + } + return r; } std::vector instance_norm_nvfuser_backward( @@ -157,9 +184,12 @@ std::vector instance_norm_nvfuser_backward( // const std::vector& output_mask, bool channels_last ) { + if (profile()) { + t1 = std::chrono::steady_clock::now(); + } InstanceNormKey backward_key; memset(&backward_key, 0, sizeof(InstanceNormKey)); - setKey(input, weight, run_mean, channels_last, eps, backward_key); + setKey(input, weight, run_mean, channels_last, backward_key); if (backward_fusion_cache.find(backward_key) == backward_fusion_cache.end()) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -200,8 +230,8 @@ std::vector instance_norm_nvfuser_backward( _running_var = castOp(DataType::Float, _running_var); } - - Double* _eps = IrBuilder::create(eps); + Double* _eps = IrBuilder::create(); + fusion->addInput(_eps); if (!run_mean.sizes().size()) { _running_mean = nullptr; _running_var = nullptr; @@ -232,6 +262,18 @@ std::vector instance_norm_nvfuser_backward( backward_fusion_cache.emplace(backward_key, std::make_unique(std::move(fusion))); } std::vector aten_inputs = { - 
input, grad_output, weight, run_mean, run_var, save_mean, save_invstd}; - return backward_fusion_cache[backward_key].get()->runFusionWithInputs(aten_inputs); + input, grad_output, weight, run_mean, run_var, save_mean, save_invstd, eps}; + if (profile()) { + t2 = std::chrono::steady_clock::now(); + } + auto r = backward_fusion_cache[backward_key].get()->runFusionWithInputs(aten_inputs); + if (profile()) { + t3 = std::chrono::steady_clock::now(); + std::chrono::duration full = t3 - t1; + std::chrono::duration pre = t2 - t1; + std::chrono::duration exec = t3 - t2; + std::cout << "NVFuserInstanceNorm Backward (full, pre-exec, exec) (" << full.count() + << ", " << pre.count() << ", " << exec.count() << ")" << std::endl; + } + return r; } From 57382e988c084abf136a8ec95da36662f9a8a359 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Mon, 25 Jul 2022 19:50:11 +0000 Subject: [PATCH 07/15] retab --- csrc/instance_norm_nvfuser_kernel.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/instance_norm_nvfuser_kernel.cu b/csrc/instance_norm_nvfuser_kernel.cu index 673af68f7..6d9559d07 100644 --- a/csrc/instance_norm_nvfuser_kernel.cu +++ b/csrc/instance_norm_nvfuser_kernel.cu @@ -242,9 +242,9 @@ std::vector instance_norm_nvfuser_backward( auto result = instance_norm_backward(_input, _grad_output, _weight, - _running_mean, + _running_mean, _running_var, - _save_mean, + _save_mean, _save_invstd, use_input_stats, _eps, From d94d55ac4b8871c6a506a36ae272408fc43251d7 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Mon, 25 Jul 2022 20:32:59 +0000 Subject: [PATCH 08/15] some overdue cleanup --- .../test_instance_norm_nvfuser.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py index c976c767f..309f3a1bc 100644 --- a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py +++ b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py @@ -69,9 +69,10 @@ def test_sweep(self): if torch.cuda.get_device_capability() >= (8, 0): dtypes.append(torch.bfloat16) for dtype, track_running_stats, channels_last, affine in itertools.product(dtypes, (False, True), (False, True), (False, True)): - self.dtype = dtype - self.track_running_stats = track_running_stats - self.channels_last = channels_last - self.affine = affine - self.init_modules() - self.check_same_output() + with self.subTest(dtype=dtype, track_running_stats=track_running_stats, channels_last=channels_last, affine=affine): + self.dtype = dtype + self.track_running_stats = track_running_stats + self.channels_last = channels_last + self.affine = affine + self.init_modules() + self.check_same_output() From 62a2ff938f28867ef03771d7cca6469969cd7a7c Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 20 Sep 2022 20:24:03 +0000 Subject: [PATCH 09/15] fix device for dummy tensor --- apex/normalization/instance_norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apex/normalization/instance_norm.py b/apex/normalization/instance_norm.py index 48870b33e..7f9ee1fd4 100644 --- a/apex/normalization/instance_norm.py +++ b/apex/normalization/instance_norm.py @@ -91,7 +91,7 @@ def __init__( factory_kwargs = {'device': device, 'dtype': dtype} super(_InstanceNormNVFuser, self).__init__( num_features, eps, momentum, affine, track_running_stats, **factory_kwargs) - self.dummy = torch.empty([], device='cuda') + self.dummy = torch.empty([], device=device) def 
_check_input_dim(self, input): raise NotImplementedError From 988cafe4d3c016c9df95f539afea4dc9b5ebf1df Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 20 Sep 2022 22:24:29 +0000 Subject: [PATCH 10/15] add test for multigpu instancenorm3dnvfuser --- .../test_instance_norm_nvfuser.py | 28 +++++++++++++++++-- tests/L0/run_test.py | 2 ++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py index 309f3a1bc..47974a27c 100644 --- a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py +++ b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py @@ -2,6 +2,7 @@ import unittest import torch +import torch.nn as nn import apex from apex.normalization import InstanceNorm3dNVFuser @@ -59,8 +60,10 @@ def check_same_output(self): else: torch.testing.assert_close(self.m.weight.grad, self.reference_m.weight.grad) if self.m.bias is not None: - if self.dtype in (torch.float16, torch.bfloat16): - torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad, atol=5e-3, rtol=5e-3) + if self.dtype == torch.float16: + torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad, atol=5e-3, rtol=7e-2) + elif self.dtype == torch.bfloat16: + torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad, atol=5e-2, rtol=1e-2) else: torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad) @@ -76,3 +79,24 @@ def test_sweep(self): self.affine = affine self.init_modules() self.check_same_output() + + @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required") + def test_multigpu(self): + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.norm = InstanceNorm3dNVFuser(4) + + def forward(self, x): + x = self.norm(x) + x = torch.sum(x, dim=(1, 2, 3, 4)) + return x + + device = torch.device(f"cuda:1") + model = Model().to(device) + + x = torch.randn(2, 4, 128, 128, 128, device=device, requires_grad=True) + y = torch.randn(2, device=device) + pred = model(x) + loss = nn.functional.mse_loss(pred, y.float()) + loss.backward() diff --git a/tests/L0/run_test.py b/tests/L0/run_test.py index 675d6bfe9..f460d0c28 100644 --- a/tests/L0/run_test.py +++ b/tests/L0/run_test.py @@ -25,12 +25,14 @@ "run_fused_layer_norm", "run_mlp", "run_transformer", + "run_instance_norm_nvfuser", ] DEFAULT_TEST_DIRS = [ "run_optimizers", "run_fused_layer_norm", "run_mlp", "run_transformer", + "run_instance_norm_nvfuser", ] From 324cee00dca82ac6a88266371383b538f3f44233 Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 6 Jan 2023 11:54:07 -0800 Subject: [PATCH 11/15] Update instance_norm.py --- apex/normalization/instance_norm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apex/normalization/instance_norm.py b/apex/normalization/instance_norm.py index 7f9ee1fd4..ce76ea6dd 100644 --- a/apex/normalization/instance_norm.py +++ b/apex/normalization/instance_norm.py @@ -129,6 +129,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def forward(self, input: Tensor) -> Tensor: assert input.is_cuda, "NVFuser InstanceNorm is CUDA only" self._check_input_dim(input) + if self.dummy.device != input.device: + self.dummy = torch.empty([], device=input.device) if self.running_mean is not None: out = InstanceNormNVFuserFunction.apply( input, self.weight if self.weight is not None else self.dummy, From 768406d67b1b552345e8d64664286780ff10975e Mon Sep 17 00:00:00 2001 From: Masaki 
Kozuki Date: Thu, 9 Feb 2023 14:47:32 -0800 Subject: [PATCH 12/15] support refactored nvfuser Signed-off-by: Masaki Kozuki --- csrc/instance_norm_nvfuser_kernel.cu | 16 +++++++++++----- setup.py | 24 +++++++++++++++++++----- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/csrc/instance_norm_nvfuser_kernel.cu b/csrc/instance_norm_nvfuser_kernel.cu index 6d9559d07..0f1add6fe 100644 --- a/csrc/instance_norm_nvfuser_kernel.cu +++ b/csrc/instance_norm_nvfuser_kernel.cu @@ -5,10 +5,16 @@ #include +// The following header file is found in `PYTORCH_HOME` +#include + +#if NVFUSER_THIRDPARTY +#include +#include +#else #include #include - -#include +#endif using namespace torch::jit::fuser::cuda; using namespace at::indexing; @@ -85,7 +91,7 @@ std::vector instance_norm_nvfuser_forward( } InstanceNormKey forward_key; setKey(input, weight, run_mean, channels_last, forward_key); - if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) { + if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -130,7 +136,7 @@ std::vector instance_norm_nvfuser_forward( if (!run_mean.sizes().size()) { _running_mean = nullptr; _running_var = nullptr; - } + } if (!weight.sizes().size()) { _weight = nullptr; _bias = nullptr; @@ -235,7 +241,7 @@ std::vector instance_norm_nvfuser_backward( if (!run_mean.sizes().size()) { _running_mean = nullptr; _running_var = nullptr; - } + } if (!weight.sizes().size()) { _weight = nullptr; } diff --git a/setup.py b/setup.py index 8f31779aa..5bea7292d 100644 --- a/setup.py +++ b/setup.py @@ -361,13 +361,27 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) if PYTORCH_HOME is not None and os.path.exists(PYTORCH_HOME): + nvfuser_is_refactored = "nvfuser" in ( + os.path.join(d) for d in os.listdir(os.path.join(PYTORCH_HOME, "third_party")) + if os.path.isdir(os.path.join(os.path.join(PYTORCH_HOME, "third_party"), d)) + ) print(PYTORCH_HOME) + include_dirs = [PYTORCH_HOME] + if nvfuser_is_refactored: + include_dirs.append(os.path.join(PYTORCH_HOME, "third_party/nvfuser/csrc")) ext_modules.append( - CUDAExtension('instance_norm_nvfuser_cuda', - ['csrc/instance_norm_nvfuser.cpp', 'csrc/instance_norm_nvfuser_kernel.cu'], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + [f"-I {PYTORCH_HOME}"])}, - ) + CUDAExtension( + name='instance_norm_nvfuser_cuda', + sources=[ + 'csrc/instance_norm_nvfuser.cpp', + 'csrc/instance_norm_nvfuser_kernel.cu', + ], + include_dirs=include_dirs, + extra_compile_args={ + "cxx": ["-O3"] + version_dependent_macros, + "nvcc": ["-O3"] + version_dependent_macros + [f"-DNVFUSER_THIRDPARTY={int(nvfuser_is_refactored)}"], + }, + ) ) if "--permutation_search" in sys.argv: From ecca2f77c3615b7afd91fb3dc733658e56e18981 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 9 Feb 2023 14:48:54 -0800 Subject: [PATCH 13/15] unittest.main Signed-off-by: Masaki Kozuki --- .../test_instance_norm_nvfuser.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py index 47974a27c..dca3c2802 100644 --- a/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py +++ b/tests/L0/run_instance_norm_nvfuser/test_instance_norm_nvfuser.py @@ -7,6 +7,7 @@ import apex from apex.normalization import 
InstanceNorm3dNVFuser + class TestInstanceNormNVFuser(unittest.TestCase): dtype = torch.float track_running_stats = False @@ -21,7 +22,7 @@ def init_modules(self): self.reference_m = torch.nn.InstanceNorm3d(self.channel_size, affine=self.affine, track_running_stats=self.track_running_stats, device='cuda', dtype=self.dtype) def check_same_output(self): - torch.manual_seed(42) + torch.manual_seed(42) for i in range(2): # exercise JIT + caching inp = torch.randint(0, 2, (self.batch_size, self.channel_size, self.spatial_size, self.spatial_size, self.spatial_size), device='cuda', requires_grad=True, dtype=self.dtype) inp2 = inp.detach().clone() @@ -78,7 +79,7 @@ def test_sweep(self): self.channels_last = channels_last self.affine = affine self.init_modules() - self.check_same_output() + self.check_same_output() @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required") def test_multigpu(self): @@ -100,3 +101,7 @@ def forward(self, x): pred = model(x) loss = nn.functional.mse_loss(pred, y.float()) loss.backward() + + +if __name__ == "__main__": + unittest.main() From 7642c1c7d30de439feb35c9da9a3abaac324b85f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 11 Feb 2023 16:12:21 -0800 Subject: [PATCH 14/15] explicit path of third_party nvfuser Signed-off-by: Masaki Kozuki --- csrc/instance_norm_nvfuser_kernel.cu | 1 + setup.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/csrc/instance_norm_nvfuser_kernel.cu b/csrc/instance_norm_nvfuser_kernel.cu index 0f1add6fe..077232e6b 100644 --- a/csrc/instance_norm_nvfuser_kernel.cu +++ b/csrc/instance_norm_nvfuser_kernel.cu @@ -9,6 +9,7 @@ #include #if NVFUSER_THIRDPARTY +#include #include #include #else diff --git a/setup.py b/setup.py index 5bea7292d..bd7ac6e05 100644 --- a/setup.py +++ b/setup.py @@ -365,10 +365,15 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int os.path.join(d) for d in os.listdir(os.path.join(PYTORCH_HOME, "third_party")) if os.path.isdir(os.path.join(os.path.join(PYTORCH_HOME, "third_party"), d)) ) + import nvfuser # NOQA print(PYTORCH_HOME) include_dirs = [PYTORCH_HOME] + library_dirs = [] + extra_link_args = [] if nvfuser_is_refactored: include_dirs.append(os.path.join(PYTORCH_HOME, "third_party/nvfuser/csrc")) + library_dirs = nvfuser.__path__ + extra_link_args.append("-lnvfuser") ext_modules.append( CUDAExtension( name='instance_norm_nvfuser_cuda', @@ -377,6 +382,8 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int 'csrc/instance_norm_nvfuser_kernel.cu', ], include_dirs=include_dirs, + library_dirs=library_dirs, + extra_link_args=extra_link_args, extra_compile_args={ "cxx": ["-O3"] + version_dependent_macros, "nvcc": ["-O3"] + version_dependent_macros + [f"-DNVFUSER_THIRDPARTY={int(nvfuser_is_refactored)}"], From 0da3ffb92ee6fbe5336602f0e3989db1cd16f880 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 11 Feb 2023 21:38:39 -0800 Subject: [PATCH 15/15] use `nvfuser_codegen` --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bd7ac6e05..54a53d198 100644 --- a/setup.py +++ b/setup.py @@ -373,7 +373,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int if nvfuser_is_refactored: include_dirs.append(os.path.join(PYTORCH_HOME, "third_party/nvfuser/csrc")) library_dirs = nvfuser.__path__ - extra_link_args.append("-lnvfuser") + extra_link_args.append("-lnvfuser_codegen") ext_modules.append( CUDAExtension( name='instance_norm_nvfuser_cuda',
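
Usage note: the series above adds InstanceNorm3dNVFuser to apex.normalization, backed by an instance_norm_nvfuser_cuda extension that setup.py only builds when the PYTORCH_HOME environment variable points at a PyTorch source tree (PATCH 01, 12, 14, 15). The snippet below is a minimal sketch of how the module is exercised, mirroring the constructor arguments and input shapes used in tests/L0/run_instance_norm_nvfuser; it assumes the extension built successfully and that a CUDA device is available, and it is illustration only, not part of the diffs above.

    import torch
    from apex.normalization import InstanceNorm3dNVFuser

    # Constructor mirrors torch.nn.InstanceNorm3d: affine enables weight/bias,
    # track_running_stats enables running_mean/running_var updates. Both
    # default to False in _InstanceNormNVFuser.
    norm = InstanceNorm3dNVFuser(num_features=7, affine=True,
                                 track_running_stats=True, device="cuda")

    # Input must be a 5D CUDA tensor (N, C, D, H, W); _check_input_dim rejects
    # anything else. Contiguous and channels_last_3d layouts are both handled;
    # channels-last inputs are permuted to an explicit channels-last layout
    # before the fused kernel runs.
    x = torch.randn(5, 7, 3, 3, 3, device="cuda", requires_grad=True)
    y = norm(x)
    y.sum().backward()  # populates x.grad, norm.weight.grad, norm.bias.grad

The first call for a given combination of dtype, dimensionality, layout, affine, and running-stats settings builds and caches a fusion (see the forward_fusion_cache/backward_fusion_cache keyed on InstanceNormKey), so a second iteration is needed to observe steady-state cost; setting the APEX_NVFUSER_PROFILE environment variable (PATCH 06) prints per-call (full, pre-exec, exec) timings for the forward and backward fusions.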