diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 911a7d449e3b..b89a125b8118 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -534,6 +534,24 @@ struct GroupNormAttrs : public tvm::AttrsNode { } }; // struct GroupNormAttrs +/*! \brief Attributes used in instance_norm operator */ +struct InstanceNormAttrs : public tvm::AttrsNode { + int channel_axis; + Array axes; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(InstanceNormAttrs, "relax.attrs.InstanceNormAttrs") { + TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel."); + TVM_ATTR_FIELD(axes).describe("The axes that along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct InstanceNormAttrs + /*! \brief Attributes used in rms_norm operator */ struct RMSNormAttrs : public tvm::AttrsNode { Array axes; diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index 28b1a819a8ae..d400721215ec 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -25,7 +25,6 @@ #define TVM_TOPI_NN_INSTANCE_NORM_H_ #include -#include #include #include @@ -43,6 +42,7 @@ using namespace tvm::te; * d_{axis_k} == r_k * \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where * d_{axis_k} == r_k + * \param channel_axis The axis of the channel dimension * \param axis The axis to normalize over (the axis along which mean and variance are * computed). * \param epsilon The epsilon value to avoid division by zero. @@ -51,9 +51,90 @@ using namespace tvm::te; * \return The normalized tensor, with the same shape as data. */ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - const Array& axis, double epsilon, + int channel_axis, const Array& axis, double epsilon, std::string name = "T_instance_norm", std::string tag = kInjective) { - return layer_norm(data, gamma, beta, axis, epsilon, name, tag); + const auto& data_type = data->dtype; + const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; + const auto& beta_type = beta.defined() ? beta->dtype : data_type; + ICHECK(data_type == gamma_type && data_type == beta_type) + << "instance_norm: data, gamma and beta must have the same type"; + ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + << "instance_norm: only support float32 and float16 for now"; + bool is_float16 = data_type == DataType::Float(16); + // sum x and x^2 + auto ndim = data->shape.size(); + ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + auto real_axis = GetRealAxis(static_cast(ndim), axis); + auto reduce_axes = MakeReduceAxes(real_axis, data); + auto target_shape = + MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true); + auto func = MakeTupleSumReducer(); + + auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, + &data](const Array& indices) { + Array eval_range; + int arg_counter = 0; + int red_counter = 0; + + for (size_t i = 0; i < ndim; ++i) { + if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { + // real_axis contains i + eval_range.push_back(reduce_axes[red_counter]); + red_counter++; + } else { + eval_range.push_back(indices[arg_counter]); + arg_counter++; + } + } + auto square = [is_float16](const PrimExpr& x) { + if (is_float16) { + return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x); + } + return x * x; + }; + if (is_float16) { + return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))}, + reduce_axes, nullptr); + } else { + return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr); + } + }; + + auto temp_x_x2 = + tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce); + + auto temp_x = temp_x_x2[0]; + auto temp_x2 = temp_x_x2[1]; + + auto reduce_extent = make_const(data->dtype, 1); + for (int i : real_axis) { + reduce_extent *= data->shape[i]; + } + auto instance_norm_func = [&](const Array& indices) { + Array reduce_indices, non_reduce_indices; + + for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { + if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { + reduce_indices.push_back(indices[i]); + } else { + non_reduce_indices.push_back(indices[i]); + } + } + Var channel; + channel = indices[channel_axis]; + auto mean = temp_x(non_reduce_indices) / reduce_extent; + auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; + auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon)); + if (is_float16) { + instance_norm = Cast(DataType::Float(16), instance_norm); + } + instance_norm = topi::multiply(instance_norm, gamma(channel)); + if (beta.defined()) { + instance_norm = topi::add(instance_norm, beta(channel)); + } + return instance_norm; + }; + return tvm::te::compute(data->shape, instance_norm_func, name, tag); } } // namespace nn diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4e7c0bf324d6..6b4396621934 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -292,6 +292,29 @@ def _zeros(self, node: fx.Node) -> relax.Var: ) return self.block_builder.emit(relax.op.zeros(size, dtype)) + def _instance_norm(self, node: fx.Node): + import numpy as np + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + gamma = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + beta = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + eps = node.args[4] if node.args[4] else 1e-05 + channel_axis = 1 + dim = len(self.shape_of(x)) + + return self.block_builder.emit( + relax.op.nn.instance_norm( + x, + gamma, + beta, + channel_axis=channel_axis, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + ########## Others ########## def create_convert_map( @@ -447,6 +470,7 @@ def create_convert_map( self.env[node.args[1]], self.env[node.args[0]] ), "group_norm.default": self._group_norm, + "instance_norm.default": self._instance_norm, "layer_norm.default": self._layer_norm, "linear.default": self._linear, "max_pool1d.default": self._max_pool1d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 33abccbe5f85..b15e406339c7 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -280,6 +280,36 @@ def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + def _instance_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + + if module.affine: + weight = self.params[module.weight] + bias = self.params[module.bias] + else: + import numpy as np + + dtype = x.struct_info.dtype + channel = int(self.shape_of(x)[1]) + weight = relax.const(np.ones(channel), dtype=dtype) + bias = relax.const(np.zeros(channel), dtype=dtype) + + eps = module.eps + channel_axis = 1 + dim = len(self.shape_of(x)) + + return self.block_builder.emit( + relax.op.nn.instance_norm( + x, + weight, + bias, + channel_axis=channel_axis, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -733,6 +763,9 @@ def create_convert_map( nn.AvgPool2d: self._avg_pool2d_module, nn.AvgPool3d: self._avg_pool3d_module, nn.BatchNorm2d: self._batch_norm_2d_module, + nn.InstanceNorm1d: self._instance_norm, + nn.InstanceNorm2d: self._instance_norm, + nn.InstanceNorm3d: self._instance_norm, nn.Conv1d: self._conv1d_module, nn.Conv2d: self._conv2d_module, nn.Conv3d: self._conv3d_module, diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 62fa0d53a93c..d12a3ee3636f 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -35,6 +35,7 @@ gelu, gelu_tanh, group_norm, + instance_norm, layer_norm, leakyrelu, log_softmax, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index c6beea315891..b193f93c0f85 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1744,6 +1744,61 @@ def group_norm( ) +def instance_norm( + data: Expr, + gamma: Expr, + beta: Expr, + channel_axis: int, + axes: List[int], + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Instance normalization + + Parameters + ---------- + data : relax.Expr + Input to which instance_norm will be applied. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + axes : Union[int, List[int]] + The axes that along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.instance_norm( # type: ignore + data, + gamma, + beta, + channel_axis, + axes, + epsilon, + center, + scale, + ) + + def rms_norm( data: Expr, weight: Expr, diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 4869ff252065..dd30215ef654 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -94,6 +94,11 @@ class LayerNormAttrs(Attrs): """Attributes used in layer_norm operator""" +@tvm.ffi.register_object("relax.attrs.InstanceNormAttrs") +class InstanceNormAttrs(Attrs): + """Attributes used in instance_norm operator""" + + @tvm.ffi.register_object("relax.attrs.DropoutAttrs") class DropoutAttrs(Attrs): """Attributes for dropout operator""" diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index f18ad6097f06..ed9802fc9e63 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -634,6 +634,19 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.nn.instance_norm") +def _nn_instance_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.instance_norm, + data=call.args[0], + gamma=call.args[1], + beta=call.args[2], + channel_axis=call.attrs.channel_axis, + axis=call.attrs.axes, + epsilon=call.attrs.epsilon, + ) + + @register_legalize("relax.nn.rms_norm") def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( diff --git a/python/tvm/topi/nn/instance_norm.py b/python/tvm/topi/nn/instance_norm.py index d119b57bfdee..a64cd2d80cb4 100644 --- a/python/tvm/topi/nn/instance_norm.py +++ b/python/tvm/topi/nn/instance_norm.py @@ -18,7 +18,7 @@ from .. import cpp -def instance_norm(data, gamma, beta, axis, epsilon=1e-5): +def instance_norm(data, gamma, beta, channel_axis, axis, epsilon=1e-5): """Instance normalization operator. Parameters @@ -44,4 +44,4 @@ def instance_norm(data, gamma, beta, axis, epsilon=1e-5): result : tvm.te.Tensor N-D with shape (d_0, d_1, ..., d_{N-1}) """ - return cpp.nn.instance_norm(data, gamma, beta, axis, epsilon) + return cpp.nn.instance_norm(data, gamma, beta, channel_axis, axis, epsilon) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b79690d3a9bd..6da83697ee15 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -415,7 +415,6 @@ InferLayoutOutput InferLayoutBatchNorm(const Call& call, } const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; - LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); // While dealing with sub layouts, its adviced to deal with batchnorm @@ -624,6 +623,106 @@ TVM_REGISTER_OP("relax.nn.group_norm") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); +/* relax.nn.instance_norm */ +TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); + +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, + double epsilon, bool center, bool scale) { + ObjectPtr attrs = make_object(); + attrs->channel_axis = std::move(channel_axis); + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.instance_norm"); + return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); +} + +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.instance_norm").set_body_typed(instance_norm); + +StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + TensorStructInfo data_sinfo = input_sinfo[0]; + + int channel_axis = -1; + if (!data_sinfo->IsUnknownNdim()) { + channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis); + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + // channel_axis must not be in axes. + if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " expects that channel_axis must not be in axes, but got channel_axis: " + << channel_axis << ", axes: " << attrs->axes); + } + } + const auto* data_shape = data_sinfo->shape.as(); + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (int i = 1; i < static_cast(op->arguments.size()); ++i) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have the same dtype, but got " + << input_sinfo[i]->dtype << " and " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have ndim=1, but got " + << input_sinfo[i]->ndim); + } + const auto* shape = input_sinfo[i]->shape.as(); + if (shape != nullptr && data_shape != nullptr) { + PrimExpr channel_size = data_shape->values[channel_axis]; + PrimExpr input_size = shape->values[0]; + if (analyzer->CanProve(channel_size != input_size)) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that the size of input " << i + << " must be equal to the size of channel_axis, but got " << input_size + << " and " << channel_size); + } + } + } + return data_sinfo; +} + +InferLayoutOutput InferLayoutInstanceNorm(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + std::vector initial_layouts; + for (size_t i = 0; i < 3; ++i) { + const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + } + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + std::vector new_axes; + for (const auto& axis : attrs->axes) { + new_axes.push_back(FindAxis(layout->layout, (axis->value))); + } + new_attrs->axes = std::move(new_axes); + new_attrs->channel_axis = FindAxis(layout->layout, attrs->channel_axis); + return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, + Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.instance_norm") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which instance_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_attr("FInferStructInfo", InferStructInfoInstanceNorm) + .set_attr("FRelaxInferLayout", InferLayoutInstanceNorm) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.nn.rms_norm */ TVM_REGISTER_NODE_TYPE(RMSNormAttrs); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 018741430199..39f8c2d73800 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -90,6 +90,10 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, Array axes, double epsilon, bool center, bool scale); +/*! \brief Compute instance normalization. */ +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, + double epsilon, bool center, bool scale); + /*! \brief Compute root mean square normalization. */ Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 7fef93550d37..4b2095a53868 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -221,8 +221,8 @@ TVM_FFI_REGISTER_GLOBAL("topi.nn.group_norm") TVM_FFI_REGISTER_GLOBAL("topi.nn.instance_norm") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::instance_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast>(), - args[4].cast()); + args[2].cast(), args[3].cast(), + args[4].cast>(), args[5].cast()); }); /* Ops from nn/rms_norm.h */ diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e6f75372d1b0..4c965bb6ffa8 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2577,6 +2577,51 @@ def main( verify_model(model, example_args, binding, expected1) +def test_instancenorm2d(): + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class InstanceNorm2d(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.InstanceNorm2d(3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.instance_norm( + input_1, + w1, + w2, + channel_axis=1, + axes=[2, 3], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = InstanceNorm2d() + binding = { + "w1": torch.ones(3).detach().numpy(), + "w2": torch.zeros(3).detach().numpy(), + } + verify_model(model, example_args, binding, expected1) + + def test_layernorm(): class LayerNorm(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f33b55085825..00c61bd31f23 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -2170,6 +2170,53 @@ def main( verify_model(model, input_info, binding, expected1) +def test_instancenorm2d(): + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class InstanceNorm2d(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.InstanceNorm2d(3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.instance_norm( + input_1, + w1, + w2, + channel_axis=1, + axes=[2, 3], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = InstanceNorm2d() + binding = { + "w1": torch.ones(3).detach().numpy(), + "w2": torch.zeros(3).detach().numpy(), + } + verify_model(model, input_info, binding, expected1) + + operator_binary_1 = [ (operator.add, R.add), (operator.sub, R.subtract),