add profile, remove scalars from cache key
eqy authored and crcrpar committed Apr 19, 2023
1 parent ccd652d commit 1fb0512
Showing 1 changed file with 54 additions and 12 deletions.
66 changes: 54 additions & 12 deletions csrc/instance_norm_nvfuser_kernel.cu
@@ -1,6 +1,7 @@
#include <iostream>
#include <map>
#include <vector>
+#include <chrono>

#include <torch/extension.h>

@@ -12,6 +13,15 @@
using namespace torch::jit::fuser::cuda;
using namespace at::indexing;

+std::chrono::time_point<std::chrono::steady_clock> t1;
+std::chrono::time_point<std::chrono::steady_clock> t2;
+std::chrono::time_point<std::chrono::steady_clock> t3;
+
+bool profile() {
+  static bool should_profile = std::getenv("APEX_NVFUSER_PROFILE") != nullptr;
+  return should_profile;
+}
+
// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
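
The profile() helper added above gates all timing work behind one environment lookup: the function-local static means std::getenv runs only once per process, so later calls are effectively free. A minimal standalone sketch of the same pattern (the demo main() is hypothetical, not part of this file):

#include <cstdlib>
#include <iostream>

bool profile() {
  // getenv is evaluated once; later calls reuse the cached result.
  static bool should_profile = std::getenv("APEX_NVFUSER_PROFILE") != nullptr;
  return should_profile;
}

int main() {
  // Run as: APEX_NVFUSER_PROFILE=1 ./demo
  std::cout << (profile() ? "profiling on" : "profiling off") << std::endl;
  return 0;
}
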
@@ -30,7 +40,6 @@ struct InstanceNormKey {
bool channels_last;
bool running_mean;
bool affine;
-  float eps;
};

auto get_dtype(c10::ScalarType dtype) {
@@ -47,14 +56,13 @@ auto get_dtype(c10::ScalarType dtype) {

// TODO: doesn't support all combinations of dtype e.g., bias, run_var, ..
// bias is assumed to match weight, run_var is assumed to match run_mean
-void setKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, const float eps, InstanceNormKey& key) {
+void setKey(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& run_mean, const bool channels_last, InstanceNormKey& key) {
memset(&key, 0, sizeof(InstanceNormKey));
key.input_dtype = input.scalar_type();// static_cast<int8_t>(input.scalar_type());
key.weight_dtype = weight.scalar_type();
key.mean_dtype = run_mean.scalar_type();
key.dim = input.sizes().size();
key.channels_last = channels_last;
-  key.eps = eps;
key.running_mean = run_mean.sizes().size() > 0;
key.affine = weight.sizes().size() ? true : false;
}
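
Dropping eps from InstanceNormKey (and from setKey) means fusions are now cached only on structural properties: dtypes, rank, layout, and the affine/running-stats flags. Previously, two calls differing only in eps produced distinct keys and therefore two compiled fusions. A simplified, hypothetical illustration of that failure mode (OldKey/NewKey are stand-ins, not types from this file):

#include <iostream>
#include <map>
#include <tuple>

struct OldKey {  // scalar value participates in the key
  int dim; float eps;
  bool operator<(const OldKey& o) const {
    return std::tie(dim, eps) < std::tie(o.dim, o.eps);
  }
};
struct NewKey {  // structure only; eps becomes a runtime fusion input
  int dim;
  bool operator<(const NewKey& o) const { return dim < o.dim; }
};

int main() {
  std::map<OldKey, int> old_cache;
  std::map<NewKey, int> new_cache;
  for (float eps : {1e-5f, 1e-3f}) {  // same shape, different eps
    old_cache.emplace(OldKey{4, eps}, /*fusion=*/0);
    new_cache.emplace(NewKey{4}, /*fusion=*/0);
  }
  std::cout << old_cache.size() << " vs " << new_cache.size() << std::endl;  // prints: 2 vs 1
}
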
Expand All @@ -72,8 +80,11 @@ std::vector<at::Tensor> instance_norm_nvfuser_forward(
const float momentum,
const float eps,
const bool channels_last) {
+  if (profile()) {
+    t1 = std::chrono::steady_clock::now();
+  }
InstanceNormKey forward_key;
-  setKey(input, weight, run_mean, channels_last, eps, forward_key);
+  setKey(input, weight, run_mean, channels_last, forward_key);
if (forward_fusion_cache.find(forward_key) == forward_fusion_cache.end()) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
@@ -110,8 +121,10 @@ std::vector<at::Tensor> instance_norm_nvfuser_forward(
// casting is done by Forward for running mean/var as it needs original inputs for aliasing
}

-  Double* _momentum = IrBuilder::create<Double>(momentum);
-  Double* _eps = IrBuilder::create<Double>(eps);
+  Double* _momentum = IrBuilder::create<Double>();
+  Double* _eps = IrBuilder::create<Double>();
+  fusion->addInput(_momentum);
+  fusion->addInput(_eps);

ForwardNormResult result;
if (!run_mean.sizes().size()) {
@@ -141,7 +154,21 @@ std::vector<at::Tensor> instance_norm_nvfuser_forward(
aten_inputs.push_back(run_mean);
aten_inputs.push_back(run_var);
}
-  return forward_fusion_cache[forward_key].get()->runFusionWithInputs(aten_inputs);
+  aten_inputs.push_back(momentum);
+  aten_inputs.push_back(eps);
+  if (profile()) {
+    t2 = std::chrono::steady_clock::now();
+  }
+  auto r = forward_fusion_cache[forward_key].get()->runFusionWithInputs(aten_inputs);
+  if (profile()) {
+    t3 = std::chrono::steady_clock::now();
+    std::chrono::duration<double> full = t3 - t1;
+    std::chrono::duration<double> pre = t2 - t1;
+    std::chrono::duration<double> exec = t3 - t2;
+    std::cout << "NVFuserInstanceNorm Forward (full, pre-exec, exec) (" << full.count()
+              << ", " << pre.count() << ", " << exec.count() << ")" << std::endl;
+  }
+  return r;
}
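
The three timestamps split each call into a setup phase (t1 to t2: key construction, cache lookup, and fusion definition on a miss) and an execution phase (t2 to t3: runFusionWithInputs). A standalone sketch of the measurement pattern, with sleeps standing in for the real work:

#include <chrono>
#include <iostream>
#include <thread>

int main() {
  auto t1 = std::chrono::steady_clock::now();
  std::this_thread::sleep_for(std::chrono::milliseconds(5));   // stand-in for setup/cache lookup
  auto t2 = std::chrono::steady_clock::now();
  std::this_thread::sleep_for(std::chrono::milliseconds(20));  // stand-in for runFusionWithInputs
  auto t3 = std::chrono::steady_clock::now();
  std::chrono::duration<double> full = t3 - t1;
  std::chrono::duration<double> pre = t2 - t1;
  std::chrono::duration<double> exec = t3 - t2;
  std::cout << "(full, pre-exec, exec) (" << full.count() << ", "
            << pre.count() << ", " << exec.count() << ")" << std::endl;
}
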

std::vector<at::Tensor> instance_norm_nvfuser_backward(
@@ -157,9 +184,12 @@ std::vector<at::Tensor> instance_norm_nvfuser_backward(
// const std::vector<bool>& output_mask,
bool channels_last
) {
+  if (profile()) {
+    t1 = std::chrono::steady_clock::now();
+  }
InstanceNormKey backward_key;
memset(&backward_key, 0, sizeof(InstanceNormKey));
-  setKey(input, weight, run_mean, channels_last, eps, backward_key);
+  setKey(input, weight, run_mean, channels_last, backward_key);
if (backward_fusion_cache.find(backward_key) == backward_fusion_cache.end()) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
@@ -200,8 +230,8 @@ std::vector<at::Tensor> instance_norm_nvfuser_backward(
_running_var = castOp(DataType::Float, _running_var);
}


-  Double* _eps = IrBuilder::create<Double>(eps);
+  Double* _eps = IrBuilder::create<Double>();
+  fusion->addInput(_eps);
if (!run_mean.sizes().size()) {
_running_mean = nullptr;
_running_var = nullptr;
@@ -232,6 +262,18 @@ std::vector<at::Tensor> instance_norm_nvfuser_backward(
backward_fusion_cache.emplace(backward_key, std::make_unique<FusionExecutorCache>(std::move(fusion)));
}
std::vector<torch::jit::IValue> aten_inputs = {
-      input, grad_output, weight, run_mean, run_var, save_mean, save_invstd};
-  return backward_fusion_cache[backward_key].get()->runFusionWithInputs(aten_inputs);
+      input, grad_output, weight, run_mean, run_var, save_mean, save_invstd, eps};
+  if (profile()) {
+    t2 = std::chrono::steady_clock::now();
+  }
+  auto r = backward_fusion_cache[backward_key].get()->runFusionWithInputs(aten_inputs);
+  if (profile()) {
+    t3 = std::chrono::steady_clock::now();
+    std::chrono::duration<double> full = t3 - t1;
+    std::chrono::duration<double> pre = t2 - t1;
+    std::chrono::duration<double> exec = t3 - t2;
+    std::cout << "NVFuserInstanceNorm Backward (full, pre-exec, exec) (" << full.count()
+              << ", " << pre.count() << ", " << exec.count() << ")" << std::endl;
+  }
+  return r;
}
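
Two practical notes on the changes above. First, scalar fusion inputs are positional: each Double declared with IrBuilder::create<Double>() and registered via fusion->addInput must be matched, in declaration order, by a value appended to aten_inputs (momentum then eps in the forward pass; eps after the tensor arguments in the backward pass). Second, the timing output appears only when the APEX_NVFUSER_PROFILE environment variable is set; each reported triple gives seconds for the full call, the pre-execution setup, and the fusion execution.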
