forked from NVIDIA/apex
commit 0da3ffb
Author: Masaki Kozuki <[email protected]>
Date:   Sat Feb 11 21:38:39 2023 -0800

    use `nvfuser_codegen`

commit 7642c1c
Author: Masaki Kozuki <[email protected]>
Date:   Sat Feb 11 16:12:21 2023 -0800

    explicit path of third_party nvfuser

    Signed-off-by: Masaki Kozuki <[email protected]>

commit ecca2f7
Author: Masaki Kozuki <[email protected]>
Date:   Thu Feb 9 14:48:54 2023 -0800

    unittest.main

    Signed-off-by: Masaki Kozuki <[email protected]>

commit 768406d
Author: Masaki Kozuki <[email protected]>
Date:   Thu Feb 9 14:47:32 2023 -0800

    support refactored nvfuser

    Signed-off-by: Masaki Kozuki <[email protected]>

commit 324cee0
Author: eqy <[email protected]>
Date:   Fri Jan 6 11:54:07 2023 -0800

    Update instance_norm.py

commit 988cafe
Author: Eddie Yan <[email protected]>
Date:   Tue Sep 20 22:24:29 2022 +0000

    add test for multigpu instancenorm3dnvfuser

commit 62a2ff9
Author: Eddie Yan <[email protected]>
Date:   Tue Sep 20 20:24:03 2022 +0000

    fix device for dummy tensor

commit d94d55a
Author: Eddie Yan <[email protected]>
Date:   Mon Jul 25 20:32:59 2022 +0000

    some overdue cleanup

commit 57382e9
Author: Eddie Yan <[email protected]>
Date:   Mon Jul 25 19:50:11 2022 +0000

    retab

commit 1fb0512
Author: Eddie Yan <[email protected]>
Date:   Tue Mar 15 23:48:19 2022 +0000

    add profile, remove scalars from cache key

commit ccd652d
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 02:57:13 2022 +0000

    sketchy test numerics twiddling

commit b893d03
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 01:50:38 2022 +0000

    fix

commit a766a1b
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 01:30:08 2022 +0000

    address comments, cleanup

commit 791d815
Author: Eddie Yan <[email protected]>
Date:   Tue Mar 1 19:28:58 2022 +0000

    add weight and bias check

commit 5edf81b
Author: Eddie Yan <[email protected]>
Date:   Fri Feb 25 23:37:44 2022 +0000

    initial check in
Showing 8 changed files with 615 additions and 0 deletions. (Only three of the files appear in this excerpt.)
@@ -1 +1,2 @@
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
from .instance_norm import InstanceNorm3dNVFuser
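The package `__init__` above re-exports the new class next to the existing fused norms, so user code can import it directly. A minimal import sketch, assuming the edited file is apex/normalization/__init__.py (the diff does not show file paths, so the package name is an assumption):

    # Hypothetical import path; assumes the file above is apex/normalization/__init__.py
    from apex.normalization import InstanceNorm3dNVFuser

The next file in the diff is the new `instance_norm.py` module that implements the autograd function and the `InstanceNorm3dNVFuser` class.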
@@ -0,0 +1,151 @@
import importlib

import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _NormBase

# The compiled extension is imported lazily on first use, so importing this
# module does not require the CUDA extension to be present at import time.
instance_norm_nvfuser_cuda = None


class InstanceNormNVFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, running_mean, running_var,
                use_input_stats, momentum, eps):
        global instance_norm_nvfuser_cuda
        if instance_norm_nvfuser_cuda is None:
            instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda")

        channels_last = input.is_contiguous(memory_format=torch.channels_last) or input.is_contiguous(memory_format=torch.channels_last_3d)
        if channels_last:
            # Permute NC... to N...C (e.g. NCHW -> NHWC) so the permuted view
            # of a channels-last tensor is plainly contiguous.
            order = [0] + [i for i in range(2, len(input.shape))] + [1]
            _input = input.permute(order)
        else:
            _input = input
        assert _input.is_contiguous()
        result = instance_norm_nvfuser_cuda.forward(_input, weight, bias, running_mean, running_var,
                                                    use_input_stats, momentum, eps, channels_last)
        if len(result) == 3:
            out, mean, invstd = result
        else:
            # When running stats are updated, the extension returns them as well.
            running_mean, running_var, out, mean, invstd = result
        ctx.use_input_stats = use_input_stats
        ctx.eps = eps
        ctx.channels_last = channels_last
        # saving for backward in "explicit channels-last format"
        ctx.save_for_backward(_input, weight, running_mean, running_var, mean, invstd)
        if channels_last:
            # Permute back to the original order (e.g. NHWC -> NCHW).
            order = [0, len(_input.shape) - 1] + [i for i in range(1, len(_input.shape) - 1)]
            out = out.permute(order)
            if len(out.shape) == 4:
                assert out.is_contiguous(memory_format=torch.channels_last)
                assert input.is_contiguous(memory_format=torch.channels_last)
            elif len(out.shape) == 5:
                assert out.is_contiguous(memory_format=torch.channels_last_3d)
                assert input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                assert False, "unhandled channels_last format variation in forward"
        return out

    @staticmethod
    def backward(ctx, grad_output):
        global instance_norm_nvfuser_cuda
        if instance_norm_nvfuser_cuda is None:
            instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda")

        if ctx.channels_last:
            order = [0] + [i for i in range(2, len(grad_output.shape))] + [1]
            grad_output = grad_output.permute(order)
        # input was saved in "explicit channels-last format"
        assert ctx.saved_tensors[0].is_contiguous()
        grad_output = grad_output.contiguous()
        # The extension expects grad_output between input and weight in the
        # saved-tensor layout: [input, grad_output, weight, running_mean,
        # running_var, mean, invstd].
        saved = list(ctx.saved_tensors)
        saved.insert(1, grad_output)
        grad_input, grad_weight, grad_bias = instance_norm_nvfuser_cuda.backward(
            *saved, ctx.use_input_stats, ctx.eps, ctx.channels_last)
        if ctx.channels_last:
            order = [0, len(grad_input.shape) - 1] + [i for i in range(1, len(grad_input.shape) - 1)]
            grad_input = grad_input.permute(order)
            if len(grad_input.shape) == 4:
                assert grad_input.is_contiguous(memory_format=torch.channels_last)
            elif len(grad_input.shape) == 5:
                assert grad_input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                assert False, "unhandled channels_last format variation in backward"
        # One gradient per forward argument; non-tensor arguments get None.
        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None

class _InstanceNormNVFuser(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_InstanceNormNVFuser, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        # Placeholder tensor passed to the extension in place of absent
        # weight/bias/running stats.
        self.dummy = torch.empty([], device=device)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)

        super(_InstanceNormNVFuser, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, input: Tensor) -> Tensor:
        assert input.is_cuda, "NVFuser InstanceNorm is CUDA only"
        self._check_input_dim(input)
        # Recreate the dummy tensor if the module was moved to another device.
        if self.dummy.device != input.device:
            self.dummy = torch.empty([], device=input.device)
        if self.running_mean is not None:
            out = InstanceNormNVFuserFunction.apply(
                input, self.weight if self.weight is not None else self.dummy,
                self.bias if self.bias is not None else self.dummy, self.running_mean, self.running_var,
                self.training or not self.track_running_stats, self.momentum, self.eps)
        else:
            out = InstanceNormNVFuserFunction.apply(
                input, self.weight if self.weight is not None else self.dummy,
                self.bias if self.bias is not None else self.dummy, self.dummy, self.dummy,
                self.training or not self.track_running_stats, self.momentum, self.eps)
        return out


class InstanceNorm3dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
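The permute orders in InstanceNormNVFuserFunction above convert between NCHW/NCDHW and the "explicit channels-last" layout the extension expects. A small sketch with hypothetical shapes, showing why the permuted view of a channels-last tensor is plainly contiguous:

    import torch

    x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)  # NCHW data, NHWC strides
    to_last = [0] + list(range(2, x.dim())) + [1]              # [0, 2, 3, 1]: NCHW -> NHWC
    to_first = [0, x.dim() - 1] + list(range(1, x.dim() - 1))  # [0, 3, 1, 2]: NHWC -> NCHW
    y = x.permute(to_last)
    assert y.is_contiguous()  # dense in memory once viewed as NHWC
    assert y.permute(to_first).is_contiguous(memory_format=torch.channels_last)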
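Putting it together, a minimal usage sketch, assuming a CUDA device, the import path assumed earlier, and a successfully built `instance_norm_nvfuser_cuda` extension (without the extension, the first forward call raises an import error):

    import torch
    from apex.normalization import InstanceNorm3dNVFuser

    norm = InstanceNorm3dNVFuser(num_features=4, affine=True).cuda()
    x = torch.randn(2, 4, 8, 8, 8, device="cuda", requires_grad=True)  # 5D: N, C, D, H, W
    y = norm(x)
    y.sum().backward()   # exercises InstanceNormNVFuserFunction.backward
    print(x.grad.shape)  # torch.Size([2, 4, 8, 8, 8])

The last file in the diff is the pybind11 binding stub for the `instance_norm_nvfuser_cuda` extension.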
@@ -0,0 +1,35 @@
#include <iostream>
#include <vector>

#include <torch/extension.h>

// Declarations only: the definitions live in a separate CUDA translation
// unit that is not part of this excerpt.
std::vector<at::Tensor> instance_norm_nvfuser_forward(
    at::Tensor input,
    at::Tensor weight,
    at::Tensor bias,
    at::Tensor run_mean,
    at::Tensor run_var,
    const bool use_input_stats,
    const float momentum,
    const float eps,
    const bool channels_last = false
);

std::vector<at::Tensor> instance_norm_nvfuser_backward(
    at::Tensor input,
    at::Tensor grad_output,
    at::Tensor weight,
    at::Tensor running_mean,
    at::Tensor running_var,
    at::Tensor save_mean,
    at::Tensor save_invstd,
    const bool use_input_stats,
    const float eps,
    // const std::vector<bool>& output_mask,
    bool channels_last = false
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &instance_norm_nvfuser_forward, "instance_norm forward (CUDA)");
  m.def("backward", &instance_norm_nvfuser_backward, "instance_norm backward (CUDA)");
}
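For completeness, one way such an extension can be compiled and loaded at runtime is torch.utils.cpp_extension.load. This is only a sketch: the source file names below are assumptions (the CUDA definitions are not shown in this diff), and apex itself builds its extensions through setup.py.

    # Hypothetical JIT build; file names are assumptions, not taken from the diff.
    from torch.utils.cpp_extension import load

    instance_norm_nvfuser_cuda = load(
        name="instance_norm_nvfuser_cuda",
        sources=[
            "instance_norm_nvfuser.cpp",        # the pybind11 stub above
            "instance_norm_nvfuser_kernel.cu",  # hypothetical CUDA definitions
        ],
        verbose=True,
    )
    # Once built, the module is importable by name, which is what
    # importlib.import_module("instance_norm_nvfuser_cuda") in instance_norm.py relies on.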