From 80fca1d29da2df37c0766a2146fc7d1f98f79ef4 Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Mon, 17 Jul 2023 21:17:34 +0000
Subject: [PATCH 1/3] [OP] Add `rms_norm` into TOPI

This PR introduces the root mean square normalization operator, `rms_norm`, into TOPI.

---
 include/tvm/topi/nn/rms_norm.h                 | 94 ++++++++++++++++++
 python/tvm/topi/nn/__init__.py                 |  1 +
 python/tvm/topi/nn/rms_norm.py                 | 45 +++++++++
 python/tvm/topi/testing/__init__.py            |  1 +
 python/tvm/topi/testing/rms_norm_python.py     | 48 ++++++++++
 src/topi/nn.cc                                 |  6 ++
 .../python/topi/python/test_topi_rms_norm.py   | 60 ++++++++++++
 7 files changed, 255 insertions(+)
 create mode 100644 include/tvm/topi/nn/rms_norm.h
 create mode 100644 python/tvm/topi/nn/rms_norm.py
 create mode 100644 python/tvm/topi/testing/rms_norm_python.py
 create mode 100644 tests/python/topi/python/test_topi_rms_norm.py

diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
new file mode 100644
index 000000000000..e743205611c3
--- /dev/null
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief root mean square normalization op constructions
+ * \file nn/rms_norm.h
+ */
+#ifndef TVM_TOPI_NN_RMS_NORM_H_
+#define TVM_TOPI_NN_RMS_NORM_H_
+
+#include <tvm/te/operation.h>
+#include <tvm/topi/reduction.h>
+#include <tvm/topi/tags.h>
+
+#include <string>
+
+namespace tvm {
+namespace topi {
+namespace nn {
+
+using namespace tvm::te;
+
+/*!
+ * \brief Root mean square normalization.
+ * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
+ * \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
+ *               d_{axis_k} == r_k
+ * \param axis The axis to normalize over.
+ * \param epsilon The epsilon value to avoid division by zero.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ * \return The normalized tensor, with the same shape as data.
+ */
+inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
+                       double epsilon, std::string name = "T_rms_norm",
+                       std::string tag = kInjective) {
+  const auto& data_type = data->dtype;
+  const auto& weight_type = weight.defined() ? weight->dtype : data_type;
+  ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
+  ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
+      << "rms_norm: only support float32 and float16 for now";
+  bool is_float16 = data_type == DataType::Float(16);
+
+  auto x = is_float16 ? cast(data, DataType::Float(32)) : data;
+  auto w = is_float16 ? cast(weight, DataType::Float(32)) : weight;
+  auto square = multiply(x, x);
+  auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
+
+  auto ndim = data->shape.size();
+  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
+  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
+  auto reduce_extent = make_const(data->dtype, 1);
+  for (int i : real_axis) {
+    reduce_extent *= data->shape[i];
+  }
+  auto rms_norm_func = [&](const Array<Var>& indices) {
+    Array<Var> reduce_indices, non_reduce_indices;
+    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
+      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
+        reduce_indices.push_back(indices[i]);
+      } else {
+        non_reduce_indices.push_back(indices[i]);
+      }
+    }
+    auto output =
+        x(indices) * w(reduce_indices) *
+        tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    return output;
+  };
+  auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
+  return is_float16 ? cast(rms_norm, DataType::Float(16)) : rms_norm;
+}
+
+}  // namespace nn
+}  // namespace topi
+}  // namespace tvm
+
+#endif  // TVM_TOPI_NN_RMS_NORM_H_
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index d65c5c45c7e0..2c549cc5b9cf 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -41,6 +41,7 @@
 from .instance_norm import instance_norm
 from .layer_norm import layer_norm
 from .group_norm import group_norm
+from .rms_norm import rms_norm
 from .local_response_norm import *
 from .bitserial_conv2d import *
 from .bitserial_dense import *
diff --git a/python/tvm/topi/nn/rms_norm.py b/python/tvm/topi/nn/rms_norm.py
new file mode 100644
index 000000000000..651ff361bfb9
--- /dev/null
+++ b/python/tvm/topi/nn/rms_norm.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Root mean square normalization operator."""
+from .. import cpp
+
+
+def rms_norm(data, weight, axis, epsilon=1e-5):
+    """Root mean square normalization operator.
+    It accepts fp16 and fp32 as the input data type. It casts the input to fp32
+    to perform the computation; the output has the same data type as the input.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+
+    weight: tvm.te.Tensor
+        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+    axis : list of int
+        The axes over which the normalization is applied
+
+    epsilon : float
+        The epsilon value to avoid division by zero.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+    """
+    return cpp.nn.rms_norm(data, weight, axis, epsilon)
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index d950a20c0559..093f84d99bd3 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -46,6 +46,7 @@
 from .instance_norm_python import instance_norm_python
 from .layer_norm_python import layer_norm_python
 from .group_norm_python import group_norm_python
+from .rms_norm_python import rms_norm_python
 from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
 from .gather_python import gather_python
diff --git a/python/tvm/topi/testing/rms_norm_python.py b/python/tvm/topi/testing/rms_norm_python.py
new file mode 100644
index 000000000000..0273b419413c
--- /dev/null
+++ b/python/tvm/topi/testing/rms_norm_python.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Root mean square normalization in Python"""
+import numpy as np
+
+
+def rms_norm_python(data, weight, axis, epsilon=1e-5):
+    """Root mean square normalization operator in Python.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+
+    weight: numpy.ndarray
+        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+    axis : int or tuple of ints
+        The axes over which the normalization is applied
+
+    epsilon : float
+        The epsilon value to avoid division by zero.
+
+    Returns
+    -------
+    result : np.ndarray
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+    """
+    old_dtype = data.dtype
+    data = data.astype("float32")
+    square_mean = np.mean(np.square(data), axis, keepdims=True)
+    result = data * weight / np.sqrt(square_mean + epsilon)
+    return result.astype(old_dtype)
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index 58b962da6afa..9ce329b20637 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -35,6 +35,7 @@
 #include <tvm/topi/nn/local_response_norm.h>
 #include <tvm/topi/nn/mapping.h>
 #include <tvm/topi/nn/pooling.h>
+#include <tvm/topi/nn/rms_norm.h>
 #include <tvm/topi/nn/softmax.h>
 
 namespace tvm {
@@ -176,5 +177,10 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
   *rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
 });
 
+/* Ops from nn/rms_norm.h */
+TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
+});
+
 }  // namespace topi
 }  // namespace tvm
diff --git a/tests/python/topi/python/test_topi_rms_norm.py b/tests/python/topi/python/test_topi_rms_norm.py
new file mode 100644
index 000000000000..a30c5bbc97f8
--- /dev/null
+++ b/tests/python/topi/python/test_topi_rms_norm.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for rms_norm.""" +import numpy as np +import pytest +import tvm +from tvm import te +from tvm import topi +from tvm.topi.utils import get_const_tuple +import tvm.topi.testing + +import tvm.testing + + +_rms_norm_schedule = { + "generic": topi.generic.schedule_injective, +} + + +# only test on llvm because schedule is missing +@tvm.testing.parametrize_targets("llvm") +@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))]) +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-4, atol=5e-4): + data = te.placeholder(shape, dtype=dtype, name="data") + scale_shape = [shape[dim] for dim in axis] + weight = te.placeholder(scale_shape, dtype=dtype, name="weight") + B = topi.nn.rms_norm(data, weight, axis, episilon) + + data_np = np.random.uniform(size=shape).astype(dtype) + weight_np = np.random.uniform(size=scale_shape).astype(dtype) + b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, episilon) + + with tvm.target.Target(target): + s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule) + s = s_func([B]) + data_tvm = tvm.nd.array(data_np, dev) + weight_tvm = tvm.nd.array(weight_np, dev) + b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) + f = tvm.build(s, [data, weight, B], target) + f(data_tvm, weight_tvm, b_tvm) + tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + tvm.testing.main() From 3eac43ba6dba2b8c83896464412cea696b62e12f Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 17 Jul 2023 22:48:16 +0000 Subject: [PATCH 2/3] apply code review suggestions --- include/tvm/topi/nn/rms_norm.h | 22 ++++++++++--------- python/tvm/topi/nn/rms_norm.py | 11 +++++----- python/tvm/topi/testing/rms_norm_python.py | 11 ++++++---- src/topi/nn.cc | 2 +- .../python/topi/python/test_topi_rms_norm.py | 13 ++++++----- 5 files changed, 34 insertions(+), 25 deletions(-) diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index e743205611c3..44d38bae6d7a 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -41,25 +41,24 @@ using namespace tvm::te; * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}] * \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and * d_{axis_k} == r_k + * \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where + * d_{axis_k} == r_k * \param axis The axis to normalize over. * \param epsilon The epsilon value to avoid division by zero. * \param name The name of the operation. * \param tag The tag to mark the operation. * \return The normalized tensor, with the same shape as data. */ -inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array& axis, - double epsilon, std::string name = "T_rms_norm", +inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias, + const Array& axis, double epsilon, std::string name = "T_rms_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& weight_type = weight.defined() ? weight->dtype : data_type; ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type"; - ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) - << "rms_norm: only support float32 and float16 for now"; - bool is_float16 = data_type == DataType::Float(16); + const auto& bias_type = bias.defined() ? 
+  ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";
 
-  auto x = is_float16 ? cast(data, DataType::Float(32)) : data;
-  auto w = is_float16 ? cast(weight, DataType::Float(32)) : weight;
-  auto square = multiply(x, x);
+  auto square = multiply(data, data);
   auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
 
   auto ndim = data->shape.size();
@@ -79,12 +78,15 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
     auto output =
-        x(indices) * w(reduce_indices) *
+        data(indices) * weight(reduce_indices) *
         tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    if (bias.defined()) {
+      output += bias(reduce_indices);
+    }
     return output;
   };
   auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
-  return is_float16 ? cast(rms_norm, DataType::Float(16)) : rms_norm;
+  return rms_norm;
 }
 
 }  // namespace nn
diff --git a/python/tvm/topi/nn/rms_norm.py b/python/tvm/topi/nn/rms_norm.py
index 651ff361bfb9..f2f5a7e67487 100644
--- a/python/tvm/topi/nn/rms_norm.py
+++ b/python/tvm/topi/nn/rms_norm.py
@@ -18,10 +18,8 @@
 from .. import cpp
 
 
-def rms_norm(data, weight, axis, epsilon=1e-5):
-    """Root mean square normalization operator.
-    It accepts fp16 and fp32 as the input data type. It casts the input to fp32
-    to perform the computation; the output has the same data type as the input.
+def rms_norm(data, weight, bias, axis, epsilon=1e-5):
+    """Root mean square normalization operator. The output has the same data type as the input.
 
     Parameters
     ----------
@@ -31,6 +29,9 @@ def rms_norm(data, weight, axis, epsilon=1e-5):
     weight: tvm.te.Tensor
         K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
 
+    bias: tvm.te.Tensor
+        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
     axis : list of int
         The axes over which the normalization is applied
 
@@ -42,4 +43,4 @@
     result : tvm.te.Tensor
         N-D with shape (d_0, d_1, ..., d_{N-1})
     """
-    return cpp.nn.rms_norm(data, weight, axis, epsilon)
+    return cpp.nn.rms_norm(data, weight, bias, axis, epsilon)
diff --git a/python/tvm/topi/testing/rms_norm_python.py b/python/tvm/topi/testing/rms_norm_python.py
index 0273b419413c..7fad5d57ce10 100644
--- a/python/tvm/topi/testing/rms_norm_python.py
+++ b/python/tvm/topi/testing/rms_norm_python.py
@@ -19,7 +19,7 @@ import numpy as np
 
 
-def rms_norm_python(data, weight, axis, epsilon=1e-5):
+def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
     """Root mean square normalization operator in Python.
 
     Parameters
     ----------
@@ -30,6 +30,9 @@ def rms_norm_python(data, weight, axis, epsilon=1e-5):
     weight: numpy.ndarray
         K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
 
+    bias: numpy.ndarray
+        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
     axis : int or tuple of ints
         The axes over which the normalization is applied
 
@@ -41,8 +44,8 @@
     result : np.ndarray
         N-D with shape (d_0, d_1, ..., d_{N-1})
     """
-    old_dtype = data.dtype
-    data = data.astype("float32")
     square_mean = np.mean(np.square(data), axis, keepdims=True)
     result = data * weight / np.sqrt(square_mean + epsilon)
-    return result.astype(old_dtype)
+    if bias is not None:
+        result += bias
+    return result
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index 9ce329b20637..ba88f01c6850 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -179,7 +179,7 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
 
 /* Ops from nn/rms_norm.h */
 TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
+  *rv = nn::rms_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
 });
 
 }  // namespace topi
diff --git a/tests/python/topi/python/test_topi_rms_norm.py b/tests/python/topi/python/test_topi_rms_norm.py
index a30c5bbc97f8..c94d28b5e83d 100644
--- a/tests/python/topi/python/test_topi_rms_norm.py
+++ b/tests/python/topi/python/test_topi_rms_norm.py
@@ -35,24 +35,27 @@
 @tvm.testing.parametrize_targets("llvm")
 @pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
-def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-4, atol=5e-4):
+def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-3, atol=1e-4):
     data = te.placeholder(shape, dtype=dtype, name="data")
     scale_shape = [shape[dim] for dim in axis]
     weight = te.placeholder(scale_shape, dtype=dtype, name="weight")
-    B = topi.nn.rms_norm(data, weight, axis, epsilon)
+    bias = te.placeholder(scale_shape, dtype=dtype, name="bias")
+    B = topi.nn.rms_norm(data, weight, bias, axis, epsilon)
 
     data_np = np.random.uniform(size=shape).astype(dtype)
     weight_np = np.random.uniform(size=scale_shape).astype(dtype)
-    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, epsilon)
+    bias_np = np.random.uniform(size=scale_shape).astype(dtype)
+    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, epsilon)
 
     with tvm.target.Target(target):
         s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
         s = s_func([B])
     data_tvm = tvm.nd.array(data_np, dev)
     weight_tvm = tvm.nd.array(weight_np, dev)
+    bias_tvm = tvm.nd.array(bias_np, dev)
     b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
-    f = tvm.build(s, [data, weight, B], target)
-    f(data_tvm, weight_tvm, b_tvm)
+    f = tvm.build(s, [data, weight, bias, B], target)
+    f(data_tvm, weight_tvm, bias_tvm, b_tvm)
     tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)

From 3d52ff3c6a9341a317da2972ed90650d7f73c4d0 Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Mon, 17 Jul 2023 23:03:09 +0000
Subject: [PATCH 3/3] add symbolic shape testcase

---
 .../python/topi/python/test_topi_rms_norm.py | 23 +++++++++++--------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/tests/python/topi/python/test_topi_rms_norm.py b/tests/python/topi/python/test_topi_rms_norm.py
index c94d28b5e83d..35a1485afa6b 100644
--- a/tests/python/topi/python/test_topi_rms_norm.py
+++ b/tests/python/topi/python/test_topi_rms_norm.py
@@ -33,18 +33,23 @@
 # only test on llvm because the schedule for other targets is missing
 @tvm.testing.parametrize_targets("llvm")
-@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
+@pytest.mark.parametrize(
+    "shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,))]
+)
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-3, atol=1e-4):
-    data = te.placeholder(shape, dtype=dtype, name="data")
-    scale_shape = [shape[dim] for dim in axis]
-    weight = te.placeholder(scale_shape, dtype=dtype, name="weight")
-    bias = te.placeholder(scale_shape, dtype=dtype, name="bias")
+    shape_te = [te.var(v[0]) if isinstance(v, tuple) else v for v in shape]
+    scale_shape_te = [shape_te[dim] for dim in axis]
+    data = te.placeholder(shape_te, dtype=dtype, name="data")
+    weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
+    bias = te.placeholder(scale_shape_te, dtype=dtype, name="bias")
     B = topi.nn.rms_norm(data, weight, bias, axis, epsilon)
 
-    data_np = np.random.uniform(size=shape).astype(dtype)
-    weight_np = np.random.uniform(size=scale_shape).astype(dtype)
-    bias_np = np.random.uniform(size=scale_shape).astype(dtype)
+    shape_np = [v[1] if isinstance(v, tuple) else v for v in shape]
+    scale_shape_np = [shape_np[dim] for dim in axis]
+    data_np = np.random.uniform(size=shape_np).astype(dtype)
+    weight_np = np.random.uniform(size=scale_shape_np).astype(dtype)
+    bias_np = np.random.uniform(size=scale_shape_np).astype(dtype)
     b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, epsilon)
 
     with tvm.target.Target(target):
@@ -53,7 +58,7 @@ def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-3, atol=1e-4):
     data_tvm = tvm.nd.array(data_np, dev)
     weight_tvm = tvm.nd.array(weight_np, dev)
     bias_tvm = tvm.nd.array(bias_np, dev)
-    b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
+    b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev)
     f = tvm.build(s, [data, weight, bias, B], target)
     f(data_tvm, weight_tvm, bias_tvm, b_tvm)
     tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
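
Usage sketch: the snippet below exercises the operator as it stands after patch 3, mirroring the test code above. The (4, 16) shape, the last-axis normalization, the fp32 dtype, the epsilon value, and the tolerances are illustrative choices borrowed from the test, not requirements of the API.

import numpy as np
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import te, topi

# Declare a (4, 16) fp32 input normalized over its last axis; weight and
# bias are 1-D tensors matching the reduced extent.
shape, axis, epsilon = (4, 16), (1,), 1e-5
data = te.placeholder(shape, dtype="float32", name="data")
weight = te.placeholder((16,), dtype="float32", name="weight")
bias = te.placeholder((16,), dtype="float32", name="bias")

# out[i, j] = data[i, j] * weight[j] / sqrt(mean(data[i, :]**2) + epsilon) + bias[j]
out = topi.nn.rms_norm(data, weight, bias, axis, epsilon)

# Schedule with the generic injective schedule and build for llvm, the
# only target the test covers.
with tvm.target.Target("llvm"):
    s = topi.generic.schedule_injective([out])
f = tvm.build(s, [data, weight, bias, out], "llvm")

dev = tvm.cpu(0)
data_np = np.random.uniform(size=shape).astype("float32")
weight_np = np.random.uniform(size=(16,)).astype("float32")
bias_np = np.random.uniform(size=(16,)).astype("float32")
out_tvm = tvm.nd.array(np.zeros(shape, dtype="float32"), dev)
f(tvm.nd.array(data_np, dev), tvm.nd.array(weight_np, dev), tvm.nd.array(bias_np, dev), out_tvm)

# Check against the numpy reference added in this series.
ref = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, epsilon)
tvm.testing.assert_allclose(out_tvm.numpy(), ref, rtol=5e-3, atol=1e-4)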