address comments, cleanup
eqy authored and crcrpar committed Apr 19, 2023
1 parent 791d815 commit a766a1b
Showing 2 changed files with 12 additions and 41 deletions.
44 changes: 6 additions & 38 deletions csrc/instance_norm_nvfuser_kernel.cu
@@ -7,38 +7,7 @@
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>

-// Hashing machinery for Params
-// Fowler–Noll–Vo hash function
-// see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
-template <typename Params>
-struct ParamsHash {
-  // Params must be a POD because we read out its memory
-  // contents as char* when hashing
-  static_assert(std::is_pod<Params>::value, "Params is not POD");
-
-  size_t operator()(const Params& params) const {
-    auto ptr = reinterpret_cast<const uint8_t*>(&params);
-    uint32_t value = 0x811C9DC5;
-    for (int i = 0; i < (int)sizeof(Params); ++i) {
-      value ^= ptr[i];
-      value *= 0x01000193;
-    }
-    return (size_t)value;
-  }
-};
-
-template <typename Params>
-struct ParamsEqual {
-  // Params must be a POD because we read out its memory
-  // contents as char* when comparing
-  static_assert(std::is_pod<Params>::value, "Params is not POD");
-
-  bool operator()(const Params& a, const Params& b) const {
-    auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
-    auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
-    return memcmp(ptr1, ptr2, sizeof(Params)) == 0;
-  }
-};
+#include <aten/src/ATen/native/utils/ParamsHash.h>

using namespace torch::jit::fuser::cuda;
using namespace at::indexing;
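The hand-rolled FNV-1a machinery deleted above now comes from ATen's shared header, which provides the same byte-wise hash (ParamsHash) and memcmp equality (ParamsEqual) for plain-old-data key types. A minimal sketch of how those helpers plug into an unordered_map, assuming the installed-PyTorch spelling <ATen/native/utils/ParamsHash.h> of the include added above; KeySketch is a hypothetical stand-in for InstanceNormKey:

#include <ATen/native/utils/ParamsHash.h>
#include <unordered_map>

// Hypothetical stand-in key; it must be plain-old-data because the hash
// and equality functors read its raw bytes.
struct KeySketch {
  int dtype;
  bool channels_last;
};

// A cache keyed on the POD struct, mirroring the fusion caches declared
// later in this file. Keys should be zero-initialized before their fields
// are set so that padding bytes never differ between equal keys.
std::unordered_map<KeySketch, int,
                   at::native::ParamsHash<KeySketch>,
                   at::native::ParamsEqual<KeySketch>>
    demo_cache;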
@@ -78,7 +47,7 @@ auto get_dtype(c10::ScalarType dtype) {

// TODO: doesn't support all combinations of dtype e.g., bias, run_var, ..
// bias is assumed to match weight, run_var is assumed to match run_mean
-void getKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, const float eps, InstanceNormKey& key) {
+void setKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, const float eps, InstanceNormKey& key) {
  memset(&key, 0, sizeof(InstanceNormKey));
  key.input_dtype = input.scalar_type(); // static_cast<int8_t>(input.scalar_type());
  key.weight_dtype = weight.scalar_type();
@@ -90,8 +59,8 @@ void getKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor&
  key.affine = weight.sizes().size() ? true : false;
}

-std::unordered_map<InstanceNormKey, std::unique_ptr<FusionExecutorCache>, ParamsHash<InstanceNormKey>, ParamsEqual<InstanceNormKey> > forward_fusion_cache;
-std::unordered_map<InstanceNormKey, std::unique_ptr<FusionExecutorCache>, ParamsHash<InstanceNormKey>, ParamsEqual<InstanceNormKey> > backward_fusion_cache;
+std::unordered_map<InstanceNormKey, std::unique_ptr<FusionExecutorCache>, at::native::ParamsHash<InstanceNormKey>, at::native::ParamsEqual<InstanceNormKey> > forward_fusion_cache;
+std::unordered_map<InstanceNormKey, std::unique_ptr<FusionExecutorCache>, at::native::ParamsHash<InstanceNormKey>, at::native::ParamsEqual<InstanceNormKey> > backward_fusion_cache;
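The real InstanceNormKey definition sits outside this hunk; a hypothetical sketch consistent with the fields setKey() assigns above, assuming nothing beyond what the diff shows:

#include <c10/core/ScalarType.h>

// Hypothetical reconstruction of the key struct; the actual definition
// lives elsewhere in this file and may contain more fields.
struct InstanceNormKeySketch {
  c10::ScalarType input_dtype;   // dtype of the input tensor
  c10::ScalarType weight_dtype;  // dtype of weight (bias is assumed to match)
  bool channels_last;            // memory-format flag
  bool affine;                   // true when a non-empty weight is passed
  float eps;                     // epsilon baked into the cached fusion
};

Because ParamsHash and ParamsEqual operate on raw bytes, the memset in setKey() is what keeps compiler-inserted padding in such a struct from producing spurious cache misses.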

std::vector<at::Tensor> instance_norm_nvfuser_forward(
    at::Tensor input,
@@ -104,8 +73,7 @@ std::vector<at::Tensor> instance_norm_nvfuser_forward(
    const float eps,
    const bool channels_last) {
  InstanceNormKey forward_key;
-  memset(&forward_key, 0, sizeof(InstanceNormKey));
-  getKey(input, weight, run_mean, channels_last, eps, forward_key);
+  setKey(input, weight, run_mean, channels_last, eps, forward_key);
  if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) {
    auto fusion = std::make_unique<Fusion>();
    FusionGuard fg(fusion.get());
@@ -166,7 +134,7 @@ std::vector<at::Tensor> instance_norm_nvfuser_forward(
      fusion->addOutput(result.mean);
      fusion->addOutput(result.invstd);
    }
-    forward_fusion_cache.emplace(forward_key, std::make_unique<FusionExecutorCache>(std::move(fusion))); // need std::move right
+    forward_fusion_cache.emplace(forward_key, std::make_unique<FusionExecutorCache>(std::move(fusion)));
  }
  std::vector<torch::jit::IValue> aten_inputs = {input, weight, bias};
  if (run_mean.sizes().size()) {
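The review comment deleted above ("// need std::move right") asks whether the std::move is required: it is, since the Fusion is held by a std::unique_ptr, which is move-only, so ownership has to be transferred explicitly into the FusionExecutorCache. A minimal illustration of the same rule:

#include <memory>
#include <utility>

struct Widget {};

int main() {
  auto a = std::make_unique<Widget>();
  // unique_ptr cannot be copied; std::move transfers ownership, leaving
  // the source pointer empty. The emplace above relies on this.
  auto b = std::move(a);
  return (a == nullptr && b != nullptr) ? 0 : 1;
}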
9 changes: 6 additions & 3 deletions
@@ -15,7 +15,7 @@ class TestInstanceNormNVFuser(unittest.TestCase):
    channel_size = 7
    spatial_size = 3

-    def setUp(self):
+    def init_modules(self):
        self.m = InstanceNorm3dNVFuser(self.channel_size, affine=self.affine, track_running_stats=self.track_running_stats, device='cuda', dtype=self.dtype)
        self.reference_m = torch.nn.InstanceNorm3d(self.channel_size, affine=self.affine, track_running_stats=self.track_running_stats, device='cuda', dtype=self.dtype)
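A likely motivation for the rename, judging from the sweep below: unittest runs setUp() automatically before every test method, but these modules can only be constructed after test_sweep has assigned self.dtype, self.affine, and the other sweep attributes, so construction moves into an explicitly invoked init_modules().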

@@ -60,10 +60,13 @@ def check_same_output(self):
        torch.testing.assert_close(self.m.bias.grad, self.reference_m.bias.grad)

    def test_sweep(self):
-        for dtype, track_running_stats, channels_last, affine in itertools.product((torch.float, torch.half), (False, True), (False, True), (False, True)):
+        dtypes = [torch.float, torch.half]
+        if torch.cuda.get_device_capability() >= (8, 0):
+            dtypes.append(torch.bfloat16)
+        for dtype, track_running_stats, channels_last, affine in itertools.product(dtypes, (False, True), (False, True), (False, True)):
            self.dtype = dtype
            self.track_running_stats = track_running_stats
            self.channels_last = channels_last
            self.affine = affine
-            self.setUp()
+            self.init_modules()
            self.check_same_output()
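The new gate adds bfloat16 only on devices reporting compute capability 8.0 (Ampere) or newer; torch.cuda.get_device_capability() returns a (major, minor) tuple, so the tuple comparison does the right thing. For reference, a sketch of the equivalent check in CUDA host code, under the same capability assumption:

#include <cuda_runtime.h>

// bfloat16 kernels generally require compute capability >= 8.0.
bool bf16_supported(int device = 0) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return false;  // conservatively report no support if the query fails
  }
  return prop.major >= 8;
}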
