18 changes: 18 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -534,6 +534,24 @@ struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
}
}; // struct GroupNormAttrs

/*! \brief Attributes used in instance_norm operator */
struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
int channel_axis;
Array<Integer> axes;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(InstanceNormAttrs, "relax.attrs.InstanceNormAttrs") {
TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel.");
TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied.");
TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero.");
TVM_ATTR_FIELD(center).describe(
"Indicating if the beta offset will be added to the normalized tensor.");
TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
}
}; // struct InstanceNormAttrs

/*! \brief Attributes used in rms_norm operator */
struct RMSNormAttrs : public tvm::AttrsNode<RMSNormAttrs> {
Array<Integer> axes;
87 changes: 84 additions & 3 deletions include/tvm/topi/nn/instance_norm.h
@@ -25,7 +25,6 @@
#define TVM_TOPI_NN_INSTANCE_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/nn/layer_norm.h>
#include <tvm/topi/tags.h>

#include <string>
@@ -43,6 +42,7 @@ using namespace tvm::te;
* d_{axis_k} == r_k
* \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
* \param channel_axis The axis of the channel dimension
* \param axis The axis to normalize over (the axis along which mean and variance are
* computed).
* \param epsilon The epsilon value to avoid division by zero.
@@ -51,9 +51,90 @@ using namespace tvm::te;
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
const Array<Integer>& axis, double epsilon,
int channel_axis, const Array<Integer>& axis, double epsilon,
std::string name = "T_instance_norm", std::string tag = kInjective) {
return layer_norm(data, gamma, beta, axis, epsilon, name, tag);
const auto& data_type = data->dtype;
const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type;
const auto& beta_type = beta.defined() ? beta->dtype : data_type;
ICHECK(data_type == gamma_type && data_type == beta_type)
<< "instance_norm: data, gamma and beta must have the same type";
ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
<< "instance_norm: only support float32 and float16 for now";
bool is_float16 = data_type == DataType::Float(16);
// sum x and x^2
auto ndim = data->shape.size();
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_axes = MakeReduceAxes(real_axis, data);
auto target_shape =
MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
auto func = MakeTupleSumReducer();

auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
&data](const Array<Var>& indices) {
Array<PrimExpr> eval_range;
int arg_counter = 0;
int red_counter = 0;

for (size_t i = 0; i < ndim; ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
// real_axis contains i
eval_range.push_back(reduce_axes[red_counter]);
red_counter++;
} else {
eval_range.push_back(indices[arg_counter]);
arg_counter++;
}
}
auto square = [is_float16](const PrimExpr& x) {
if (is_float16) {
return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x);
}
return x * x;
};
if (is_float16) {
return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))},
reduce_axes, nullptr);
} else {
return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr);
}
};

auto temp_x_x2 =
tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);

auto temp_x = temp_x_x2[0];
auto temp_x2 = temp_x_x2[1];

auto reduce_extent = make_const(data->dtype, 1);
for (int i : real_axis) {
reduce_extent *= data->shape[i];
}
auto instance_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices;

for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
reduce_indices.push_back(indices[i]);
} else {
non_reduce_indices.push_back(indices[i]);
}
}
Var channel = indices[channel_axis];
auto mean = temp_x(non_reduce_indices) / reduce_extent;
auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon));
if (is_float16) {
instance_norm = Cast(DataType::Float(16), instance_norm);
}
if (gamma.defined()) {
instance_norm = topi::multiply(instance_norm, gamma(channel));
}
if (beta.defined()) {
instance_norm = topi::add(instance_norm, beta(channel));
}
return instance_norm;
};
return tvm::te::compute(data->shape, instance_norm_func, name, tag);
}
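For reference, a NumPy sketch of what this kernel computes (illustrative, not part of the patch): the mean and variance are reduced over axis independently per sample, and gamma/beta are indexed along channel_axis.

import numpy as np

def instance_norm_ref(data, gamma, beta, channel_axis=1, axes=(2, 3), epsilon=1e-5):
    # reduce over the normalization axes, keeping dims for broadcasting
    mean = data.mean(axis=axes, keepdims=True)
    var = data.var(axis=axes, keepdims=True)
    out = (data - mean) / np.sqrt(var + epsilon)
    # gamma/beta are 1-D over the channel dimension; reshape them for broadcasting
    bshape = [1] * data.ndim
    bshape[channel_axis] = data.shape[channel_axis]
    return out * gamma.reshape(bshape) + beta.reshape(bshape)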

} // namespace nn
24 changes: 24 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -292,6 +292,29 @@ def _zeros(self, node: fx.Node) -> relax.Var:
)
return self.block_builder.emit(relax.op.zeros(size, dtype))

def _instance_norm(self, node: fx.Node):
import numpy as np

x = self.env[node.args[0]]
channel = int(self.shape_of(x)[1])
dtype = x.struct_info.dtype
gamma = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype))
beta = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype))
eps = node.args[4] if node.args[4] else 1e-05
channel_axis = 1
dim = len(self.shape_of(x))

return self.block_builder.emit(
relax.op.nn.instance_norm(
x,
gamma,
beta,
channel_axis=channel_axis,
axes=list(range(2, dim)),
epsilon=eps,
)
)
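A minimal usage sketch for this path (hedged: assumes a torch version with torch.export and this frontend; the model and shapes are illustrative):

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.InstanceNorm2d(3, affine=True)

    def forward(self, x):
        return self.norm(x)

exported = export(Model().eval(), (torch.randn(1, 3, 8, 8),))
mod = from_exported_program(exported)  # instance_norm.default is handled by _instance_norm above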

########## Others ##########

def create_convert_map(
@@ -447,6 +470,7 @@ def create_convert_map(
self.env[node.args[1]], self.env[node.args[0]]
),
"group_norm.default": self._group_norm,
"instance_norm.default": self._instance_norm,
"layer_norm.default": self._layer_norm,
"linear.default": self._linear,
"max_pool1d.default": self._max_pool1d,
33 changes: 33 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -280,6 +280,36 @@ def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var:

return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))

def _instance_norm(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]

if module.affine:
weight = self.params[module.weight]
bias = self.params[module.bias]
else:
import numpy as np

dtype = x.struct_info.dtype
channel = int(self.shape_of(x)[1])
weight = relax.const(np.ones(channel), dtype=dtype)
bias = relax.const(np.zeros(channel), dtype=dtype)

eps = module.eps
channel_axis = 1
dim = len(self.shape_of(x))

return self.block_builder.emit(
relax.op.nn.instance_norm(
x,
weight,
bias,
channel_axis=channel_axis,
axes=list(range(2, dim)),
epsilon=eps,
)
)
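A minimal sketch of the fx path, matching the module-level mappings added below (hedged; shapes and dtype are illustrative):

import torch
from tvm.relax.frontend.torch import from_fx

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.InstanceNorm2d(3, affine=False)

    def forward(self, x):
        return self.norm(x)

graph_module = torch.fx.symbolic_trace(Model())
mod = from_fx(graph_module, [((1, 3, 8, 8), "float32")])  # nn.InstanceNorm2d maps to _instance_norm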

def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -733,6 +763,9 @@ def create_convert_map(
nn.AvgPool2d: self._avg_pool2d_module,
nn.AvgPool3d: self._avg_pool3d_module,
nn.BatchNorm2d: self._batch_norm_2d_module,
nn.InstanceNorm1d: self._instance_norm,
nn.InstanceNorm2d: self._instance_norm,
nn.InstanceNorm3d: self._instance_norm,
nn.Conv1d: self._conv1d_module,
nn.Conv2d: self._conv2d_module,
nn.Conv3d: self._conv3d_module,
1 change: 1 addition & 0 deletions python/tvm/relax/op/nn/__init__.py
@@ -35,6 +35,7 @@
gelu,
gelu_tanh,
group_norm,
instance_norm,
layer_norm,
leakyrelu,
log_softmax,
55 changes: 55 additions & 0 deletions python/tvm/relax/op/nn/nn.py
@@ -1744,6 +1744,61 @@ def group_norm(
)


def instance_norm(
data: Expr,
gamma: Expr,
beta: Expr,
channel_axis: int,
axes: List[int],
epsilon: float = 1e-5,
center: bool = True,
scale: bool = True,
) -> Expr:
r"""
Instance normalization

Parameters
----------
data : relax.Expr
Input to which instance_norm will be applied.

gamma : relax.Expr
The gamma scale factor.

beta : relax.Expr
The beta offset factor.

channel_axis : int
The axis that represents the channel.

axes : Union[int, List[int]]
The axes along which the normalization is applied.

epsilon : float
Small float added to variance to avoid dividing by zero.

center : bool
Indicating if the beta offset will be added to the normalized tensor.

scale : bool
Indicating if the gamma scale will be multiplied.

Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axes, int):
axes = [axes]
return _ffi_api.instance_norm( # type: ignore
data,
gamma,
beta,
channel_axis,
axes,
epsilon,
center,
scale,
)
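A small construction sketch (hedged; shapes are illustrative). The resulting call carries the attributes declared in InstanceNormAttrs:

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 3, 8, 8), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((3,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((3,), "float32"))

call = relax.op.nn.instance_norm(x, gamma, beta, channel_axis=1, axes=[2, 3])
print(call.attrs.channel_axis, list(call.attrs.axes), call.attrs.epsilon)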


def rms_norm(
data: Expr,
weight: Expr,
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
@@ -94,6 +94,11 @@ class LayerNormAttrs(Attrs):
"""Attributes used in layer_norm operator"""


@tvm.ffi.register_object("relax.attrs.InstanceNormAttrs")
class InstanceNormAttrs(Attrs):
"""Attributes used in instance_norm operator"""


@tvm.ffi.register_object("relax.attrs.DropoutAttrs")
class DropoutAttrs(Attrs):
"""Attributes for dropout operator"""
13 changes: 13 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -634,6 +634,19 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.nn.instance_norm")
def _nn_instance_norm(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.nn.instance_norm,
data=call.args[0],
gamma=call.args[1],
beta=call.args[2],
channel_axis=call.attrs.channel_axis,
axis=call.attrs.axes,
epsilon=call.attrs.epsilon,
)
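An end-to-end sketch of this legalization (hedged; assumes the op and TOPI kernel from this PR): build a module containing instance_norm with the BlockBuilder, then run LegalizeOps to lower it to a call into topi.nn.instance_norm.

from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((1, 3, 8, 8), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((3,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((3,), "float32"))
with bb.function("main", [x, gamma, beta]):
    out = bb.emit(relax.op.nn.instance_norm(x, gamma, beta, channel_axis=1, axes=[2, 3]))
    bb.emit_func_output(out)

mod = relax.transform.LegalizeOps()(bb.get())
mod.show()  # "main" now calls a generated TIR PrimFunc derived from topi.nn.instance_norm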


@register_legalize("relax.nn.rms_norm")
def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
4 changes: 2 additions & 2 deletions python/tvm/topi/nn/instance_norm.py
@@ -18,7 +18,7 @@
from .. import cpp


def instance_norm(data, gamma, beta, axis, epsilon=1e-5):
def instance_norm(data, gamma, beta, channel_axis, axis, epsilon=1e-5):
"""Instance normalization operator.

Parameters
@@ -44,4 +44,4 @@ def instance_norm(data, gamma, beta, axis, epsilon=1e-5):
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.instance_norm(data, gamma, beta, axis, epsilon)
return cpp.nn.instance_norm(data, gamma, beta, channel_axis, axis, epsilon)
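A te-level sketch of the updated signature (hedged; shapes are illustrative):

from tvm import te, topi

data = te.placeholder((1, 3, 8, 8), name="data", dtype="float32")
gamma = te.placeholder((3,), name="gamma", dtype="float32")
beta = te.placeholder((3,), name="beta", dtype="float32")

out = topi.nn.instance_norm(data, gamma, beta, channel_axis=1, axis=[2, 3], epsilon=1e-5)
prim_func = te.create_prim_func([data, gamma, beta, out])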