Skip to content

Commit 683e7a4

Browse files
authored
[TOPI] Add instance_norm operator (#14410)
1 parent 776cf5b commit 683e7a4

File tree

7 files changed

+235
-0
lines changed

7 files changed

+235
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \brief instance normalization op constructions
22+
* \file nn/instance_norm.h
23+
*/
24+
#ifndef TVM_TOPI_NN_INSTANCE_NORM_H_
25+
#define TVM_TOPI_NN_INSTANCE_NORM_H_
26+
27+
#include <tvm/te/operation.h>
28+
#include <tvm/topi/nn/layer_norm.h>
29+
#include <tvm/topi/tags.h>
30+
31+
#include <string>
32+
33+
namespace tvm {
34+
namespace topi {
35+
namespace nn {
36+
37+
using namespace tvm::te;
38+
39+
/*!
40+
* \brief Instance normalization.
41+
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
42+
* \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
43+
* d_{axis_k} == r_k
44+
* \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
45+
* d_{axis_k} == r_k
46+
* \param axis The axis to normalize over (the axis along which mean and variance are
47+
* computed).
48+
* \param epsilon The epsilon value to avoid division by zero.
49+
* \param name The name of the operation.
50+
* \param tag The tag to mark the operation.
51+
* \return The normalized tensor, with the same shape as data.
52+
*/
53+
inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
54+
const Array<Integer>& axis, double epsilon,
55+
std::string name = "T_instance_norm", std::string tag = kInjective) {
56+
return layer_norm(data, gamma, beta, axis, epsilon, name, tag);
57+
}
58+
59+
} // namespace nn
60+
} // namespace topi
61+
} // namespace tvm
62+
63+
#endif // TVM_TOPI_NN_INSTANCE_NORM_H_

python/tvm/topi/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from .bnn import *
3939
from .qnn import *
4040
from .upsampling import *
41+
from .instance_norm import instance_norm
4142
from .layer_norm import layer_norm
4243
from .group_norm import group_norm
4344
from .local_response_norm import *
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Instance normalization operator."""
18+
from .. import cpp
19+
20+
21+
def instance_norm(data, gamma, beta, axis, epsilon=1e-5):
22+
"""Instance normalization operator.
23+
24+
Parameters
25+
----------
26+
data : tvm.te.Tensor
27+
N-D with shape (d_0, d_1, ..., d_{N-1})
28+
29+
gamma: tvm.te.Tensor
30+
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
31+
32+
beta: tvm.te.Tensor
33+
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
34+
35+
axis : list of int
36+
Axis over the normalization applied (the axis along which the mean and variance are
37+
computed)
38+
39+
epsilon : float
40+
The epsilon value to avoid division by zero.
41+
42+
Returns
43+
-------
44+
result : tvm.te.Tensor
45+
N-D with shape (d_0, d_1, ..., d_{N-1})
46+
"""
47+
return cpp.nn.instance_norm(data, gamma, beta, axis, epsilon)

python/tvm/topi/testing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .reorg_python import reorg_python
4444
from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python
4545
from .roi_pool_python import roi_pool_nchw_python
46+
from .instance_norm_python import instance_norm_python
4647
from .layer_norm_python import layer_norm_python
4748
from .group_norm_python import group_norm_python
4849
from .lrn_python import lrn_python
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
18+
"""Instance normalization in python"""
19+
import numpy as np
20+
21+
22+
def instance_norm_python(data, gamma, beta, axis, epsilon=1e-5):
23+
"""Instance normalization operator in Python.
24+
25+
Parameters
26+
----------
27+
data : numpy.ndarray
28+
N-D with shape (d_0, d_1, ..., d_{N-1})
29+
30+
gamma: numpy.ndarray
31+
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
32+
33+
beta: numpy.ndarray
34+
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
35+
36+
axis : int or tuple of ints
37+
Axis over the normalization applied
38+
39+
epsilon : float
40+
The epsilon value to avoid division by zero.
41+
42+
Returns
43+
-------
44+
result : np.ndarray
45+
N-D with shape (d_0, d_1, ..., d_{N-1})
46+
"""
47+
mean = np.mean(data, axis, keepdims=True)
48+
var = np.var(data, axis, keepdims=True)
49+
result = (data - mean) / np.sqrt(var + epsilon)
50+
result *= gamma
51+
if beta is not None:
52+
result += beta
53+
return result

src/topi/nn.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <tvm/topi/nn/dilate.h>
3131
#include <tvm/topi/nn/flatten.h>
3232
#include <tvm/topi/nn/group_norm.h>
33+
#include <tvm/topi/nn/instance_norm.h>
3334
#include <tvm/topi/nn/layer_norm.h>
3435
#include <tvm/topi/nn/local_response_norm.h>
3536
#include <tvm/topi/nn/mapping.h>
@@ -170,5 +171,10 @@ TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body([](TVMArgs args, TVMRetValue*
170171
static_cast<int>(args[4]), args[5], static_cast<double>(args[6]));
171172
});
172173

174+
/* Ops from nn/instance_norm.h */
175+
TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
176+
*rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
177+
});
178+
173179
} // namespace topi
174180
} // namespace tvm
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Test code for instance_norm."""
18+
import numpy as np
19+
import pytest
20+
import tvm
21+
from tvm import te
22+
from tvm import topi
23+
from tvm.topi.utils import get_const_tuple
24+
import tvm.topi.testing
25+
26+
import tvm.testing
27+
28+
29+
_instance_norm_schedule = {
30+
"generic": topi.generic.schedule_injective,
31+
}
32+
33+
34+
# only test on llvm because schedule is missing
35+
@tvm.testing.parametrize_targets("llvm")
36+
@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
37+
def test_instance_norm(
38+
target, dev, shape, axis, episilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5
39+
):
40+
data = te.placeholder(shape, dtype=dtype, name="data")
41+
scale_shape = [shape[dim] for dim in axis]
42+
gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma")
43+
beta = te.placeholder(scale_shape, dtype=dtype, name="beta")
44+
B = topi.nn.instance_norm(data, gamma, beta, axis, episilon)
45+
46+
data_np = np.random.uniform(size=shape).astype(dtype)
47+
gamma_np = np.random.uniform(size=scale_shape).astype(dtype)
48+
beta_np = np.random.uniform(size=scale_shape).astype(dtype)
49+
b_np = tvm.topi.testing.instance_norm_python(data_np, gamma_np, beta_np, axis, episilon)
50+
51+
with tvm.target.Target(target):
52+
s_func = tvm.topi.testing.dispatch(target, _instance_norm_schedule)
53+
s = s_func([B])
54+
data_tvm = tvm.nd.array(data_np, dev)
55+
gamma_tvm = tvm.nd.array(gamma_np, dev)
56+
beta_tvm = tvm.nd.array(beta_np, dev)
57+
b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
58+
f = tvm.build(s, [data, gamma, beta, B], target)
59+
f(data_tvm, gamma_tvm, beta_tvm, b_tvm)
60+
tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
61+
62+
63+
if __name__ == "__main__":
64+
tvm.testing.main()

0 commit comments

Comments
 (0)