
Commit cb34604

[TOPI] Add generic batch norm (#9694)
* Add topi batch norm and tests
* Handle none values correctly
* Return correct num outputs for onnx
* Use moving var/mean and update tests
* Add a test for batch norm folding
* Fix comment
* Format with black
* Re-order test args to match interface
* Call fold constant manually
1 parent 5557b8c commit cb34604

12 files changed: +418 −2 lines changed


python/tvm/relay/frontend/onnx.py

Lines changed: 4 additions & 1 deletion
@@ -474,7 +474,10 @@ def _impl_v1(cls, inputs, attr, params):
             op_name="batch_norm",
             ignores=["spatial", "is_test", "consumed_inputs", "momentum", "training_mode"],
         )(inputs, attr, params)
-        return out[0]
+        # We only support test mode, so we return data, moving_mean, moving_var,
+        # and then moving_mean and moving_var again as placeholders for
+        # the expected "saved_mean", "saved_var".
+        return _expr.TupleWrapper(_expr.Tuple((*out, out[1], out[2])), 5)


 class InstanceNorm(OnnxOpConverter):
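
A minimal sketch (not part of this commit; shapes and variable names are illustrative) of the tuple the converter now builds: relay.nn.batch_norm returns three outputs, and the frontend repeats the running statistics so the op exposes the five outputs ONNX expects in test mode.

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 8, 8))
gamma, beta, mean, var = (relay.var(n, shape=(3,)) for n in ("gamma", "beta", "mean", "var"))

# relay.nn.batch_norm yields a TupleWrapper of (data, moving_mean, moving_var)
out = relay.nn.batch_norm(data, gamma, beta, mean, var)
# Repeat the running statistics as stand-ins for ONNX's saved_mean / saved_var
five = relay.Tuple([out[0], out[1], out[2], out[1], out[2]])
print(relay.Function(relay.analysis.free_vars(five), five))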

python/tvm/relay/op/nn/_nn.py

Lines changed: 5 additions & 0 deletions
@@ -152,6 +152,11 @@ def legalize_batch_matmul(attrs, inputs, types):
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


+# batch_norm
+reg.register_strategy("nn.batch_norm", strategy.batch_norm_strategy)
+reg.register_pattern("nn.batch_norm", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # sparse_dense
 @reg.register_compute("nn.sparse_dense")
 def compute_sparse_dense(attrs, inputs, out_type):

python/tvm/relay/op/strategy/generic.py

Lines changed: 23 additions & 0 deletions
@@ -848,6 +848,29 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
     return strategy


+# batch_norm
+def wrap_compute_batch_norm(topi_compute):
+    """wrap batch_norm topi compute"""
+
+    def _compute_batch_norm(attrs, inputs, out_type):
+        return topi_compute(*inputs, attrs.axis, attrs.epsilon, attrs.center, attrs.scale)
+
+    return _compute_batch_norm
+
+
+@override_native_generic_func("batch_norm_strategy")
+def batch_norm_strategy(attrs, inputs, out_type, target):
+    """batch_norm generic strategy"""
+    logger.warning("batch_norm is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_batch_norm(topi.nn.batch_norm),
+        wrap_topi_schedule(topi.generic.schedule_batch_norm),
+        name="batch_norm.generic",
+    )
+    return strategy
+
+
 # sparse dense
 def wrap_compute_sparse_dense(topi_compute):
     """wrap sparse dense topi compute"""

python/tvm/topi/generic/nn.py

Lines changed: 17 additions & 0 deletions
@@ -815,6 +815,23 @@ def schedule_batch_matmul(outs):
     return _default_schedule(outs, False)


+def schedule_batch_norm(outs):
+    """Schedule for batch_norm
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of batch_norm
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_correlation_nchw(outs):
     """Schedule for correlation_nchw

python/tvm/topi/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
 from .bitserial_conv2d import *
 from .bitserial_dense import *
 from .batch_matmul import *
+from .batch_norm import *
 from .sparse import *
 from .pad import *
 from .fifo_buffer import *

python/tvm/topi/nn/batch_norm.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Batch normalization."""
+import typing
+
+from tvm import te
+from tvm import topi
+
+
+def batch_norm(
+    data: te.Tensor,
+    gamma: te.Tensor,
+    beta: te.Tensor,
+    moving_mean: te.Tensor,
+    moving_var: te.Tensor,
+    axis: typing.Optional[int] = None,
+    epsilon: typing.Optional[float] = None,
+    center: typing.Optional[bool] = None,
+    scale: typing.Optional[bool] = None,
+) -> typing.List[te.Tensor]:
+    """Batch normalization layer (Ioffe and Szegedy, 2014).
+
+    Normalizes the input at each batch, i.e. applies a transformation
+    that maintains the mean activation close to 0 and the activation
+    standard deviation close to 1.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input to be batch-normalized.
+
+    gamma : tvm.te.Tensor
+        Scale factor to be applied to the normalized tensor.
+
+    beta : tvm.te.Tensor
+        Offset to be applied to the normalized tensor.
+
+    moving_mean : tvm.te.Tensor
+        Running mean of input.
+
+    moving_var : tvm.te.Tensor
+        Running variance of input.
+
+    axis : int, optional, default=1
+        Specify along which shape axis the normalization should occur.
+
+    epsilon : float, optional, default=1e-5
+        Small float added to variance to avoid dividing by zero.
+
+    center : bool, optional, default=True
+        If True, add offset of beta to normalized tensor. If False,
+        beta is ignored.
+
+    scale : bool, optional, default=True
+        If True, scale normalized tensor by gamma. If False, gamma
+        is ignored.
+
+    Returns
+    -------
+    output : list of tvm.te.Tensor
+        Normalized data with same shape as input
+
+    moving_mean : tvm.te.Tensor
+        Running mean of input.
+
+    moving_var : tvm.te.Tensor
+        Running variance of input.
+    """
+    if axis is None:
+        axis = 1
+
+    if epsilon is None:
+        epsilon = 1e-5
+
+    if center is None:
+        center = True
+
+    if scale is None:
+        scale = True
+
+    shape = [1] * len(data.shape)
+    shape[axis] = data.shape[axis]
+
+    moving_mean_rs = topi.reshape(moving_mean, shape)
+    moving_var_rs = topi.reshape(moving_var, shape)
+
+    out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
+
+    if scale:
+        out = out * topi.reshape(gamma, shape)
+    if center:
+        out = out + topi.reshape(beta, shape)
+
+    # Moving mean and var aren't updated during test. To avoid
+    # placeholder reuse, we multiply by 1 and return them.
+    return [out, moving_mean * 1, moving_var * 1]
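
A minimal sketch (not part of the diff; shapes are illustrative, dtype is the float32 default) of lowering this compute with the matching generic schedule:

import tvm
from tvm import te, topi

data = te.placeholder((2, 4, 8, 8), name="data")
gamma = te.placeholder((4,), name="gamma")
beta = te.placeholder((4,), name="beta")
mean = te.placeholder((4,), name="moving_mean")
var = te.placeholder((4,), name="moving_var")

# Three outputs: normalized data plus the (unchanged) running statistics.
out, new_mean, new_var = topi.nn.batch_norm(data, gamma, beta, mean, var)
s = topi.generic.schedule_batch_norm([out, new_mean, new_var])
f = tvm.build(s, [data, gamma, beta, mean, var, out, new_mean, new_var], target="llvm")

The resulting function can be called like any other TE kernel; a full run against the NumPy reference is sketched below, after that reference.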

python/tvm/topi/testing/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
 from .gather_nd_python import gather_nd_python
 from .strided_slice_python import strided_slice_python, strided_set_python
 from .batch_matmul import batch_matmul
+from .batch_norm import batch_norm
 from .slice_axis_python import slice_axis_python
 from .sequence_mask_python import sequence_mask
 from .poolnd_python import poolnd_python
python/tvm/topi/testing/batch_norm.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Batch Normalization implemented in Numpy."""
+import numpy as np
+
+
+def batch_norm(
+    x: np.ndarray,
+    gamma: np.ndarray,
+    beta: np.ndarray,
+    moving_mean: np.ndarray,
+    moving_var: np.ndarray,
+    axis: int,
+    epsilon: float,
+    center: bool,
+    scale: bool,
+):
+    """Batch Normalization operator implemented in Numpy.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Input to be batch-normalized.
+
+    gamma : np.ndarray
+        Scale factor to be applied to the normalized tensor.
+
+    beta : np.ndarray
+        Offset to be applied to the normalized tensor.
+
+    moving_mean : np.ndarray
+        Running mean of input.
+
+    moving_var : np.ndarray
+        Running variance of input.
+
+    axis : int
+        Specify along which shape axis the normalization should occur.
+
+    epsilon : float
+        Small float added to variance to avoid dividing by zero.
+
+    center : bool
+        If True, add offset of beta to normalized tensor. If False,
+        beta is ignored.
+
+    scale : bool
+        If True, scale normalized tensor by gamma. If False, gamma
+        is ignored.
+
+    Returns
+    -------
+    output : np.ndarray
+        Normalized data with same shape as input
+
+    moving_mean : np.ndarray
+        Running mean of input.
+
+    moving_var : np.ndarray
+        Running variance of input.
+    """
+    shape = [1] * len(x.shape)
+    shape[axis] = x.shape[axis]
+
+    moving_mean_rs = moving_mean.reshape(shape)
+    moving_var_rs = moving_var.reshape(shape)
+
+    out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)
+
+    if scale:
+        out = out * gamma.reshape(shape)
+    if center:
+        out = out + beta.reshape(shape)
+
+    return [out, moving_mean, moving_var]
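
A condensed sketch (not part of the diff, in the spirit of the new tests; shapes are illustrative) comparing the TOPI compute against this NumPy reference:

import numpy as np
import tvm
import tvm.topi.testing
from tvm import te, topi

shape, axis, eps = (2, 4, 8, 8), 1, 1e-5
x_np = np.random.rand(*shape).astype("float32")
p_np = [np.random.rand(shape[axis]).astype("float32") for _ in range(4)]  # gamma, beta, mean, var

ref_out, _, _ = tvm.topi.testing.batch_norm(x_np, *p_np, axis, eps, True, True)

x = te.placeholder(shape, name="x")
params = [te.placeholder((shape[axis],), name=n) for n in ("gamma", "beta", "mean", "var")]
te_outs = topi.nn.batch_norm(x, *params, axis, eps, True, True)
s = topi.generic.schedule_batch_norm(te_outs)
f = tvm.build(s, [x, *params, *te_outs], target="llvm")

dev = tvm.cpu()
tvm_outs = [tvm.nd.empty(shape, "float32", dev),
            tvm.nd.empty((shape[axis],), "float32", dev),
            tvm.nd.empty((shape[axis],), "float32", dev)]
f(tvm.nd.array(x_np, dev), *[tvm.nd.array(p, dev) for p in p_np], *tvm_outs)
np.testing.assert_allclose(tvm_outs[0].numpy(), ref_out, rtol=1e-5)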

src/relay/op/nn/nn.cc

Lines changed: 2 additions & 1 deletion
@@ -745,7 +745,8 @@ bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   reporter->Assign(types[4], TensorType({axis_size}, data->dtype));

   // output is a tuple of the normed data (same shape as input), new running mean,
-  // and new running average (the latter two are both vectors of length dim)
+  // new running variance, saved mean and saved variance (the latter are all
+  // vectors of length dim)
   std::vector<Type> fields;
   auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}), data->dtype);
   fields.push_back(TensorType(data->shape, data->dtype));

src/topi/schedule.cc

Lines changed: 3 additions & 0 deletions
@@ -230,6 +230,9 @@ TVM_REGISTER_GENERIC_FUNC(schedule_dense)
 TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
     .set_default(WrapSchedule(topi::generic::default_schedule));

+TVM_REGISTER_GENERIC_FUNC(schedule_batch_norm)
+    .set_default(WrapSchedule(topi::generic::default_schedule));
+
 TVM_REGISTER_GENERIC_FUNC(schedule_pool)
     .set_default(WrapSchedule(topi::generic::default_schedule))
     .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
