Commit

initial check in
eqy authored and crcrpar committed Apr 19, 2023
1 parent 817e818 commit 5edf81b
Showing 7 changed files with 525 additions and 0 deletions.
1 change: 1 addition & 0 deletions apex/normalization/__init__.py
@@ -1 +1,2 @@
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
from .instance_norm import InstanceNorm3dNVFuser
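This one-line addition re-exports the new module from the package namespace; given the file paths in this commit, downstream code would be expected to import it as shown in this sketch (assuming apex is installed with the extension built):

from apex.normalization import InstanceNorm3dNVFuser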
149 changes: 149 additions & 0 deletions apex/normalization/instance_norm.py
@@ -0,0 +1,149 @@
import importlib

import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _NormBase

global instance_norm_nvfuser_cuda
instance_norm_nvfuser_cuda = None

class InstanceNormNVFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, running_mean, running_var,
                use_input_stats, momentum, eps):
        global instance_norm_nvfuser_cuda
        if instance_norm_nvfuser_cuda is None:
            instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda")

        channels_last = input.is_contiguous(memory_format=torch.channels_last) or input.is_contiguous(memory_format=torch.channels_last_3d)
        if channels_last:
            order = [0] + [i for i in range(2, len(input.shape))] + [1]
            _input = input.permute(order)
        else:
            _input = input
        assert _input.is_contiguous()
        result = instance_norm_nvfuser_cuda.forward(_input, weight, bias, running_mean, running_var,
                                                    use_input_stats, momentum, eps, channels_last)
        if len(result) == 3:
            out, mean, invstd = result
        else:
            running_mean, running_var, out, mean, invstd = result
        ctx.use_input_stats = use_input_stats
        ctx.eps = eps
        ctx.channels_last = channels_last
        # saving for backward in "explicit channels-last format"
        ctx.save_for_backward(_input, weight, running_mean, running_var, mean, invstd)
        if channels_last:
            order = [0, len(_input.shape) - 1] + [i for i in range(1, len(_input.shape) - 1)]
            out = out.permute(order)
            if len(out.shape) == 4:
                assert out.is_contiguous(memory_format=torch.channels_last)
                assert input.is_contiguous(memory_format=torch.channels_last)
            elif len(out.shape) == 5:
                assert out.is_contiguous(memory_format=torch.channels_last_3d)
                assert input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                assert False, "unhandled channels_last format variation in forward"
        return out

    @staticmethod
    def backward(ctx, grad_output):
        global instance_norm_nvfuser_cuda
        if instance_norm_nvfuser_cuda is None:
            instance_norm_nvfuser_cuda = importlib.import_module("instance_norm_nvfuser_cuda")

        if ctx.channels_last:
            order = [0] + [i for i in range(2, len(grad_output.shape))] + [1]
            grad_output = grad_output.permute(order)
        # input was saved in "explicit channels-last format"
        assert ctx.saved_tensors[0].is_contiguous()
        grad_output = grad_output.contiguous()
        saved = list(ctx.saved_tensors)
        saved.insert(1, grad_output)
        running_mean = saved[3]
        running_var = saved[4]
        mean = saved[-2]
        var = saved[-1]
        grad_input, grad_weight, grad_bias = instance_norm_nvfuser_cuda.backward(*saved, ctx.use_input_stats, ctx.eps, ctx.channels_last)
        if ctx.channels_last:
            order = [0, len(grad_input.shape) - 1] + [i for i in range(1, len(grad_input.shape) - 1)]
            grad_input = grad_input.permute(order)
            if len(grad_input.shape) == 4:
                assert grad_input.is_contiguous(memory_format=torch.channels_last)
            elif len(grad_input.shape) == 5:
                assert grad_input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                assert False, "unhandled channels_last format variation in backward"
        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


class _InstanceNormNVFuser(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_InstanceNormNVFuser, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.dummy = torch.empty([], device='cuda')

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)

        super(_InstanceNormNVFuser, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, input: Tensor) -> Tensor:
        assert input.is_cuda, "NVFuser InstanceNorm is CUDA only"
        self._check_input_dim(input)
        if self.running_mean is not None:
            out = InstanceNormNVFuserFunction.apply(
                input, self.weight if self.weight is not None else self.dummy,
                self.bias if self.bias is not None else self.dummy, self.running_mean, self.running_var,
                self.training or not self.track_running_stats, self.momentum, self.eps)
        else:
            out = InstanceNormNVFuserFunction.apply(
                input, self.weight if self.weight is not None else self.dummy,
                self.bias if self.bias is not None else self.dummy, self.dummy, self.dummy,
                self.training or not self.track_running_stats, self.momentum, self.eps)
        return out

class InstanceNorm3dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
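Taken together, the file above defines an NVFuser-backed counterpart of torch.nn.InstanceNorm3d that dispatches to the compiled extension. A minimal usage sketch, assuming a CUDA device and that the instance_norm_nvfuser_cuda extension has been built; the shapes, affine=True, and the channels_last_3d layout are illustrative choices, not requirements taken from this commit:

import torch
from apex.normalization import InstanceNorm3dNVFuser

# 5D input (N, C, D, H, W) on the GPU, as required by _check_input_dim
norm = InstanceNorm3dNVFuser(num_features=16, affine=True).cuda()
x = torch.randn(8, 16, 4, 32, 32, device="cuda").to(memory_format=torch.channels_last_3d)
x.requires_grad_()

y = norm(x)          # dispatches to InstanceNormNVFuserFunction.apply
y.sum().backward()   # exercises the fused backward as well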

35 changes: 35 additions & 0 deletions csrc/instance_norm_nvfuser.cpp
@@ -0,0 +1,35 @@
#include <iostream>
#include <vector>

#include <torch/extension.h>

std::vector<at::Tensor> instance_norm_nvfuser_forward(
    at::Tensor input,
    at::Tensor weight,
    at::Tensor bias,
    at::Tensor run_mean,
    at::Tensor run_var,
    const bool use_input_stats,
    const float momentum,
    const float eps,
    const bool channels_last = false
);

std::vector<at::Tensor> instance_norm_nvfuser_backward(
    at::Tensor input,
    at::Tensor grad_output,
    at::Tensor weight,
    at::Tensor running_mean,
    at::Tensor running_var,
    at::Tensor save_mean,
    at::Tensor save_invstd,
    const bool use_input_stats,
    const float eps,
    // const std::vector<bool>& output_mask,
    bool channels_last = false
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &instance_norm_nvfuser_forward, "instance_norm forward (CUDA)");
  m.def("backward", &instance_norm_nvfuser_backward, "instance_norm backward (CUDA)");
}
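For the lazy importlib.import_module("instance_norm_nvfuser_cuda") call in instance_norm.py to succeed, the binding above has to be compiled into an extension module of that name. Below is a minimal sketch of how it might be wired into a setuptools build with torch's extension helpers; the package metadata and the .cu kernel filename are assumptions for illustration (the actual kernel source is among the changed files not shown in this excerpt):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="instance_norm_nvfuser",
    ext_modules=[
        CUDAExtension(
            # the module name must match the importlib lookup in instance_norm.py
            name="instance_norm_nvfuser_cuda",
            sources=[
                "csrc/instance_norm_nvfuser.cpp",
                "csrc/instance_norm_nvfuser_kernel.cu",  # assumed filename for the kernel source
            ],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)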