Squashed commit of NVIDIA#1582
commit 0da3ffb
Author: Masaki Kozuki <[email protected]>
Date:   Sat Feb 11 21:38:39 2023 -0800

    use `nvfuser_codegen`

commit 7642c1c
Author: Masaki Kozuki <[email protected]>
Date:   Sat Feb 11 16:12:21 2023 -0800

    explicit path of third_party nvfuser

    Signed-off-by: Masaki Kozuki <[email protected]>

commit ecca2f7
Author: Masaki Kozuki <[email protected]>
Date:   Thu Feb 9 14:48:54 2023 -0800

    unittest.main

    Signed-off-by: Masaki Kozuki <[email protected]>

commit 768406d
Author: Masaki Kozuki <[email protected]>
Date:   Thu Feb 9 14:47:32 2023 -0800

    support refactored nvfuser

    Signed-off-by: Masaki Kozuki <[email protected]>

commit 324cee0
Author: eqy <[email protected]>
Date:   Fri Jan 6 11:54:07 2023 -0800

    Update instance_norm.py

commit 988cafe
Author: Eddie Yan <[email protected]>
Date:   Tue Sep 20 22:24:29 2022 +0000

    add test for multigpu instancenorm3dnvfuser

commit 62a2ff9
Author: Eddie Yan <[email protected]>
Date:   Tue Sep 20 20:24:03 2022 +0000

    fix device for dummy tensor

commit d94d55a
Author: Eddie Yan <[email protected]>
Date:   Mon Jul 25 20:32:59 2022 +0000

    some overdue cleanup

commit 57382e9
Author: Eddie Yan <[email protected]>
Date:   Mon Jul 25 19:50:11 2022 +0000

    retab

commit 1fb0512
Author: Eddie Yan <[email protected]>
Date:   Tue Mar 15 23:48:19 2022 +0000

    add profile, remove scalars from cache key

commit ccd652d
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 02:57:13 2022 +0000

    sketchy test numerics twiddling

commit b893d03
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 01:50:38 2022 +0000

    fix

commit a766a1b
Author: Eddie Yan <[email protected]>
Date:   Thu Mar 10 01:30:08 2022 +0000

    address comments, cleanup

commit 791d815
Author: Eddie Yan <[email protected]>
Date:   Tue Mar 1 19:28:58 2022 +0000

    add weight and bias check

commit 5edf81b
Author: Eddie Yan <[email protected]>
Date:   Fri Feb 25 23:37:44 2022 +0000

    initial check in
crcrpar committed Jul 25, 2023
1 parent 50ac842 commit bae1f93
Showing 8 changed files with 615 additions and 0 deletions.
1 change: 1 addition & 0 deletions apex/normalization/__init__.py
@@ -1 +1,2 @@
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
from .instance_norm import InstanceNorm3dNVFuser
151 changes: 151 additions & 0 deletions apex/normalization/instance_norm.py
@@ -0,0 +1,151 @@
import importlib

import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _NormBase

# lazily-imported handle to the compiled extension (see csrc/instance_norm_nvfuser.cpp)
global instance_norm_nvfuser_cuda
instance_norm_nvfuser_cuda = None

class InstanceNormNVFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, running_mean, running_var,
                use_input_stats, momentum, eps):
        global instance_norm_nvfuser_cuda
        if instance_norm_nvfuser_cuda is None:
            instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda")

        channels_last = input.is_contiguous(memory_format=torch.channels_last) or input.is_contiguous(memory_format=torch.channels_last_3d)
        if channels_last:
            # permute NCHW/NCDHW -> NHWC/NDHWC so the channel dim is innermost
            order = [0] + [i for i in range(2, len(input.shape))] + [1]
            _input = input.permute(order)
        else:
            _input = input
        assert _input.is_contiguous()
        result = instance_norm_nvfuser_cuda.forward(_input, weight, bias, running_mean, running_var,
                                                    use_input_stats, momentum, eps, channels_last)
        # five return values means the extension also produced updated running stats
        if len(result) == 3:
            out, mean, invstd = result
        else:
            running_mean, running_var, out, mean, invstd = result
        ctx.use_input_stats = use_input_stats
        ctx.eps = eps
        ctx.channels_last = channels_last
        # saving for backward in "explicit channels-last format"
        ctx.save_for_backward(_input, weight, running_mean, running_var, mean, invstd)
        if channels_last:
            # permute back to the NCHW/NCDHW view the caller expects
            order = [0, len(_input.shape) - 1] + [i for i in range(1, len(_input.shape) - 1)]
            out = out.permute(order)
            if len(out.shape) == 4:
                assert out.is_contiguous(memory_format=torch.channels_last)
                assert input.is_contiguous(memory_format=torch.channels_last)
            elif len(out.shape) == 5:
                assert out.is_contiguous(memory_format=torch.channels_last_3d)
                assert input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                assert False, "unhandled channels_last format variation in forward"
        return out

    @staticmethod
    def backward(ctx, grad_output):
        global instance_norm_nvfuser_cuda
        if instance_norm_nvfuser_cuda is None:
            instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda")

        if ctx.channels_last:
            order = [0] + [i for i in range(2, len(grad_output.shape))] + [1]
            grad_output = grad_output.permute(order)
        # input was saved in "explicit channels-last format"
        assert ctx.saved_tensors[0].is_contiguous()
        grad_output = grad_output.contiguous()
        saved = list(ctx.saved_tensors)
        saved.insert(1, grad_output)
        # aliases into `saved`; the extension consumes them positionally via *saved
        running_mean = saved[3]
        running_var = saved[4]
        mean = saved[-2]
        var = saved[-1]
        grad_input, grad_weight, grad_bias = instance_norm_nvfuser_cuda.backward(*saved, ctx.use_input_stats, ctx.eps, ctx.channels_last)
        if ctx.channels_last:
            order = [0, len(grad_input.shape) - 1] + [i for i in range(1, len(grad_input.shape) - 1)]
            grad_input = grad_input.permute(order)
            if len(grad_input.shape) == 4:
                assert grad_input.is_contiguous(memory_format=torch.channels_last)
            elif len(grad_input.shape) == 5:
                assert grad_input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                assert False, "unhandled channels_last format variation in backward"
        # gradients only for input, weight, and bias; None for the remaining arguments
        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


class _InstanceNormNVFuser(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_InstanceNormNVFuser, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        # placeholder tensor passed to the fuser when weight/bias/running stats are absent
        self.dummy = torch.empty([], device=device)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)

        super(_InstanceNormNVFuser, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, input: Tensor) -> Tensor:
        assert input.is_cuda, "NVFuser InstanceNorm is CUDA only"
        self._check_input_dim(input)
        if self.dummy.device != input.device:
            self.dummy = torch.empty([], device=input.device)
        if self.running_mean is not None:
            out = InstanceNormNVFuserFunction.apply(
                input, self.weight if self.weight is not None else self.dummy,
                self.bias if self.bias is not None else self.dummy, self.running_mean, self.running_var,
                self.training or not self.track_running_stats, self.momentum, self.eps)
        else:
            out = InstanceNormNVFuserFunction.apply(
                input, self.weight if self.weight is not None else self.dummy,
                self.bias if self.bias is not None else self.dummy, self.dummy, self.dummy,
                self.training or not self.track_running_stats, self.momentum, self.eps)
        return out

class InstanceNorm3dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
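
For orientation, a minimal usage sketch of the class above (this assumes apex was installed with the instance_norm_nvfuser_cuda extension built; the feature count and shapes are illustrative only):

import torch
from apex.normalization import InstanceNorm3dNVFuser

norm = InstanceNorm3dNVFuser(num_features=8, affine=True).cuda()
x = torch.randn(2, 8, 16, 16, 16, device="cuda", requires_grad=True)  # 5D NCDHW input
y = norm(x)          # dispatches to InstanceNormNVFuserFunction.forward
y.sum().backward()   # exercises the fused backward kernel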

35 changes: 35 additions & 0 deletions csrc/instance_norm_nvfuser.cpp
@@ -0,0 +1,35 @@
#include <iostream>
#include <vector>

#include <torch/extension.h>

std::vector<at::Tensor> instance_norm_nvfuser_forward(
    at::Tensor input,
    at::Tensor weight,
    at::Tensor bias,
    at::Tensor run_mean,
    at::Tensor run_var,
    const bool use_input_stats,
    const float momentum,
    const float eps,
    const bool channels_last = false
);

std::vector<at::Tensor> instance_norm_nvfuser_backward(
    at::Tensor input,
    at::Tensor grad_output,
    at::Tensor weight,
    at::Tensor running_mean,
    at::Tensor running_var,
    at::Tensor save_mean,
    at::Tensor save_invstd,
    const bool use_input_stats,
    const float eps,
    // const std::vector<bool>& output_mask,
    bool channels_last = false
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &instance_norm_nvfuser_forward, "instance_norm forward (CUDA)");
  m.def("backward", &instance_norm_nvfuser_backward, "instance_norm backward (CUDA)");
}
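
For reference, one plausible way to compile this binding is torch's JIT extension loader; a minimal sketch, assuming the NVFuser kernel implementations live in companion source files (the source list below is an assumption for illustration; apex's actual build wires this through its setup.py):

from torch.utils.cpp_extension import load

# Hypothetical ad-hoc build; the module name must match what
# apex/normalization/instance_norm.py imports at runtime.
instance_norm_nvfuser_cuda = load(
    name="instance_norm_nvfuser_cuda",
    sources=["csrc/instance_norm_nvfuser.cpp"],  # plus the kernel source(s), not shown in this diff
    verbose=True,
)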