96 changes: 96 additions & 0 deletions include/tvm/topi/nn/rms_norm.h
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief root mean square normalization op constructions
* \file nn/rms_norm.h
*/
#ifndef TVM_TOPI_NN_RMS_NORM_H_
#define TVM_TOPI_NN_RMS_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
* \brief Root mean square normalization.
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
* \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
* \param axis The axis to normalize over.
* \param epsilon The epsilon value to avoid division by zero.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias,
const Array<Integer>& axis, double epsilon, std::string name = "T_rms_norm",
std::string tag = kInjective) {
const auto& data_type = data->dtype;
const auto& weight_type = weight.defined() ? weight->dtype : data_type;
ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
const auto& bias_type = bias.defined() ? bias->dtype : data_type;
ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";

auto square = multiply(data, data);
auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);

auto ndim = data->shape.size();
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_extent = make_const(data->dtype, 1);
for (int i : real_axis) {
reduce_extent *= data->shape[i];
}
auto rms_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices;
for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
reduce_indices.push_back(indices[i]);
} else {
non_reduce_indices.push_back(indices[i]);
}
}
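// Per output element: data * weight * rsqrt(mean(data^2 over axis) + epsilon);
// the mean is realized as square_sum / reduce_extent, and bias is added afterwards.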
auto output =
data(indices) * weight(reduce_indices) *
tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
if (bias.defined()) {
output += bias(reduce_indices);
}
return output;
};
auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
return rms_norm;
}

} // namespace nn
} // namespace topi
} // namespace tvm

#endif // TVM_TOPI_NN_RMS_NORM_H_
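
For reference, the computation constructed above can be written as

$$\mathrm{out} = \mathrm{data}\cdot\mathrm{weight}\cdot\frac{1}{\sqrt{\operatorname{mean}_{\text{axis}}\left(\mathrm{data}^2\right)+\epsilon}}+\mathrm{bias}$$

where the mean over `axis` is realized as `square_sum / reduce_extent`, and `bias` (when defined) is added after the scaling rather than inside the normalization.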
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
@@ -41,6 +41,7 @@
from .instance_norm import instance_norm
from .layer_norm import layer_norm
from .group_norm import group_norm
from .rms_norm import rms_norm
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
46 changes: 46 additions & 0 deletions python/tvm/topi/nn/rms_norm.py
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Root mean square normalization operator."""
from .. import cpp


def rms_norm(data, weight, bias, axis, epsilon=1e-5):
"""Root mean square normalization operator. The output will have the same data type as input.

Parameters
----------
data : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})

weight : tvm.te.Tensor
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

bias : tvm.te.Tensor
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

axis : list of int
The axes over which the normalization is applied.

epsilon : float
The epsilon value to avoid division by zero.

Returns
-------
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.rms_norm(data, weight, bias, axis, epsilon)
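
A minimal end-to-end sketch of this wrapper (the shapes, target, and epsilon below are illustrative, not part of this patch):

import tvm
from tvm import te, topi

# Normalize a (4, 16) tensor over its last axis; weight and bias share that axis's extent.
data = te.placeholder((4, 16), dtype="float32", name="data")
weight = te.placeholder((16,), dtype="float32", name="weight")
bias = te.placeholder((16,), dtype="float32", name="bias")
out = topi.nn.rms_norm(data, weight, bias, axis=(1,), epsilon=1e-5)

# Build with the default schedule on llvm.
s = te.create_schedule(out.op)
f = tvm.build(s, [data, weight, bias, out], "llvm")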
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
@@ -46,6 +46,7 @@
from .instance_norm_python import instance_norm_python
from .layer_norm_python import layer_norm_python
from .group_norm_python import group_norm_python
from .rms_norm_python import rms_norm_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python
51 changes: 51 additions & 0 deletions python/tvm/topi/testing/rms_norm_python.py
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Root mean square normalization in python"""
import numpy as np


def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
"""Root mean square normalization operator in Python.

Parameters
----------
data : numpy.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})

weight : numpy.ndarray
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

bias : numpy.ndarray
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

axis : int or tuple of ints
The axis or axes over which the normalization is applied.

epsilon : float
The epsilon value to avoid division by zero.

Returns
-------
result : np.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
square_mean = np.mean(np.square(data), axis, keepdims=True)
result = data * weight / np.sqrt(square_mean + epsilon)
if bias is not None:
result += bias
return result
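
As a quick sanity check of this reference (illustrative shapes): with unit weight, no bias, and epsilon set to zero, every row of the output has an RMS of 1 (up to float32 rounding).

data = np.random.uniform(size=(4, 16)).astype("float32")
ones = np.ones(16, dtype="float32")
out = rms_norm_python(data, ones, None, axis=1, epsilon=0.0)
rms = np.sqrt(np.mean(np.square(out), axis=1))
assert np.allclose(rms, 1.0, rtol=1e-5)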
6 changes: 6 additions & 0 deletions src/topi/nn.cc
@@ -35,6 +35,7 @@
#include <tvm/topi/nn/local_response_norm.h>
#include <tvm/topi/nn/mapping.h>
#include <tvm/topi/nn/pooling.h>
#include <tvm/topi/nn/rms_norm.h>
#include <tvm/topi/nn/softmax.h>

namespace tvm {
@@ -176,5 +177,10 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
*rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
});

/* Ops from nn/rms_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::rms_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
});

} // namespace topi
} // namespace tvm
68 changes: 68 additions & 0 deletions tests/python/topi/python/test_topi_rms_norm.py
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for rms_norm."""
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
from tvm.topi.utils import get_const_tuple
import tvm.topi.testing

import tvm.testing


_rms_norm_schedule = {
"generic": topi.generic.schedule_injective,
}


# Only test on llvm because schedules for other targets are missing.
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize(
"shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,))]
)
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-3, atol=1e-4):
shape_te = [te.var(v[0]) if isinstance(v, tuple) else v for v in shape]
scale_shape_te = [shape_te[dim] for dim in axis]
data = te.placeholder(shape_te, dtype=dtype, name="data")
weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
bias = te.placeholder(scale_shape_te, dtype=dtype, name="bias")
B = topi.nn.rms_norm(data, weight, bias, axis, epsilon)

shape_np = [v[1] if isinstance(v, tuple) else v for v in shape]
scale_shape_np = [shape_np[dim] for dim in axis]
data_np = np.random.uniform(size=shape_np).astype(dtype)
weight_np = np.random.uniform(size=scale_shape_np).astype(dtype)
bias_np = np.random.uniform(size=scale_shape_np).astype(dtype)
b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, epsilon)

with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
s = s_func([B])
data_tvm = tvm.nd.array(data_np, dev)
weight_tvm = tvm.nd.array(weight_np, dev)
bias_tvm = tvm.nd.array(bias_np, dev)
b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev)
f = tvm.build(s, [data, weight, bias, B], target)
f(data_tvm, weight_tvm, bias_tvm, b_tvm)
tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)


if __name__ == "__main__":
tvm.testing.main()