Commit 3eac43b

apply code review suggestions
1 parent 80fca1d

File tree

5 files changed (+34, -25 lines)

include/tvm/topi/nn/rms_norm.h

Lines changed: 12 additions & 10 deletions

@@ -41,25 +41,24 @@ using namespace tvm::te;
  * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
  * \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
  *        d_{axis_k} == r_k
+ * \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
+ *        d_{axis_k} == r_k
  * \param axis The axis to normalize over.
  * \param epsilon The epsilon value to avoid division by zero.
  * \param name The name of the operation.
  * \param tag The tag to mark the operation.
  * \return The normalized tensor, with the same shape as data.
  */
-inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
-                       double epsilon, std::string name = "T_rms_norm",
+inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias,
+                       const Array<Integer>& axis, double epsilon, std::string name = "T_rms_norm",
                        std::string tag = kInjective) {
   const auto& data_type = data->dtype;
   const auto& weight_type = weight.defined() ? weight->dtype : data_type;
   ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
-  ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
-      << "rms_norm: only support float32 and float16 for now";
-  bool is_float16 = data_type == DataType::Float(16);
+  const auto& bias_type = bias.defined() ? bias->dtype : data_type;
+  ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";
 
-  auto x = is_float16 ? cast(data, DataType::Float(32)) : data;
-  auto w = is_float16 ? cast(weight, DataType::Float(32)) : weight;
-  auto square = multiply(x, x);
+  auto square = multiply(data, data);
   auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
 
   auto ndim = data->shape.size();

@@ -79,12 +78,15 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Int
       }
     }
     auto output =
-        x(indices) * w(reduce_indices) *
+        data(indices) * weight(reduce_indices) *
         tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    if (bias.defined()) {
+      output += bias(reduce_indices);
+    }
     return output;
   };
   auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
-  return is_float16 ? cast(rms_norm, DataType::Float(16)) : rms_norm;
+  return rms_norm;
 }
 
 }  // namespace nn
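A note on what changed behaviorally in this header: the fp32 upcast path for fp16 inputs is gone, so the whole computation now runs in the input dtype, which is presumably why the test below relaxes rtol from 5e-4 to 5e-3. To read the index bookkeeping: indices addresses data, the positions listed in axis are split off into reduce_indices (addressing weight and now bias), and the rest form non_reduce_indices (addressing the reduced square_sum). A minimal NumPy sketch of the equivalent computation, with illustrative shapes borrowed from the test below:

import numpy as np

# Normalize a (4, 16, 16) tensor over axis = (1, 2); weight and bias use the reduced shape.
data = np.random.uniform(size=(4, 16, 16)).astype("float16")
weight = np.random.uniform(size=(16, 16)).astype("float16")
bias = np.random.uniform(size=(16, 16)).astype("float16")

# square_sum / reduce_extent in the kernel is just the mean of squares over axis.
mean_sq = np.mean(np.square(data), axis=(1, 2), keepdims=True)
out = data * weight / np.sqrt(mean_sq + np.float16(1e-5)) + bias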

python/tvm/topi/nn/rms_norm.py

Lines changed: 6 additions & 5 deletions

@@ -18,10 +18,8 @@
 from .. import cpp
 
 
-def rms_norm(data, weight, axis, epsilon=1e-5):
-    """Root mean square normalization operator.
-    It accepts fp16 and fp32 as input data type. It will cast the input to fp32
-    to perform the computation. The output will have the same data type as input.
+def rms_norm(data, weight, bias, axis, epsilon=1e-5):
+    """Root mean square normalization operator. The output will have the same data type as input.
 
     Parameters
     ----------
@@ -31,6 +29,9 @@ def rms_norm(data, weight, axis, epsilon=1e-5):
     weight: tvm.te.Tensor
         K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
 
+    bias: tvm.te.Tensor
+        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
     axis : list of int
         Axis over the normalization applied
 
@@ -42,4 +43,4 @@ def rms_norm(data, weight, axis, epsilon=1e-5):
     result : tvm.te.Tensor
         N-D with shape (d_0, d_1, ..., d_{N-1})
     """
-    return cpp.nn.rms_norm(data, weight, axis, epsilon)
+    return cpp.nn.rms_norm(data, weight, bias, axis, epsilon)
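For callers, a short usage sketch of the updated Python signature (shapes and the epsilon value are illustrative; this mirrors how the test in this commit drives the op):

from tvm import te, topi

data = te.placeholder((4, 16), dtype="float32", name="data")
weight = te.placeholder((16,), dtype="float32", name="weight")
bias = te.placeholder((16,), dtype="float32", name="bias")

# bias is now the third argument, ahead of axis.
out = topi.nn.rms_norm(data, weight, bias, axis=(1,), epsilon=1e-5)
# out is a te.Tensor with the same shape as data.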

python/tvm/topi/testing/rms_norm_python.py

Lines changed: 7 additions & 4 deletions

@@ -19,7 +19,7 @@
 import numpy as np
 
 
-def rms_norm_python(data, weight, axis, epsilon=1e-5):
+def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
     """Root mean square normalization operator in Python.
 
     Parameters
@@ -30,6 +30,9 @@ def rms_norm_python(data, weight, axis, epsilon=1e-5):
     weight: numpy.ndarray
         K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
 
+    bias: numpy.ndarray
+        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
     axis : int or tuple of ints
         Axis over the normalization applied
 
@@ -41,8 +44,8 @@ def rms_norm_python(data, weight, axis, epsilon=1e-5):
     result : np.ndarray
         N-D with shape (d_0, d_1, ..., d_{N-1})
     """
-    old_dtype = data.dtype
-    data = data.astype("float32")
     square_mean = np.mean(np.square(data), axis, keepdims=True)
     result = data * weight / np.sqrt(square_mean + epsilon)
-    return result.astype(old_dtype)
+    if bias is not None:
+        result += bias
+    return result
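Because the reference implementation treats bias as a plain additive term, its semantics can be sanity-checked directly; a minimal sketch (inputs are arbitrary random arrays):

import numpy as np
from tvm.topi.testing import rms_norm_python

data = np.random.uniform(size=(4, 16)).astype("float32")
weight = np.random.uniform(size=(16,)).astype("float32")
bias = np.random.uniform(size=(16,)).astype("float32")

out_no_bias = rms_norm_python(data, weight, None, (1,), epsilon=1e-5)  # bias=None skips the add
out_bias = rms_norm_python(data, weight, bias, (1,), epsilon=1e-5)
np.testing.assert_allclose(out_bias, out_no_bias + bias, rtol=1e-6)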

src/topi/nn.cc

Lines changed: 1 addition & 1 deletion

@@ -179,7 +179,7 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
 
 /* Ops from nn/rms_norm.h */
 TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
+  *rv = nn::rms_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
 });
 
 }  // namespace topi
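Since the registration now forwards five positional arguments, any direct PackedFunc call must match that order. A minimal sketch of invoking the global directly (normally reached through topi.nn.rms_norm; placeholder shapes are illustrative):

import tvm
from tvm import te

f = tvm.get_global_func("topi.nn.rms_norm")
data = te.placeholder((4, 16), "float32", "data")
weight = te.placeholder((16,), "float32", "weight")
bias = te.placeholder((16,), "float32", "bias")
# args[0]=data, args[1]=weight, args[2]=bias, args[3]=axis, args[4]=epsilon
out = f(data, weight, bias, [1], 1e-5)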

tests/python/topi/python/test_topi_rms_norm.py

Lines changed: 8 additions & 5 deletions

@@ -35,24 +35,27 @@
 @tvm.testing.parametrize_targets("llvm")
 @pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
-def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-4, atol=5e-4):
+def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-3, atol=1e-4):
     data = te.placeholder(shape, dtype=dtype, name="data")
     scale_shape = [shape[dim] for dim in axis]
     weight = te.placeholder(scale_shape, dtype=dtype, name="weight")
-    B = topi.nn.rms_norm(data, weight, axis, epsilon)
+    bias = te.placeholder(scale_shape, dtype=dtype, name="bias")
+    B = topi.nn.rms_norm(data, weight, bias, axis, epsilon)
 
     data_np = np.random.uniform(size=shape).astype(dtype)
     weight_np = np.random.uniform(size=scale_shape).astype(dtype)
-    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, epsilon)
+    bias_np = np.random.uniform(size=scale_shape).astype(dtype)
+    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, epsilon)
 
     with tvm.target.Target(target):
         s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
         s = s_func([B])
         data_tvm = tvm.nd.array(data_np, dev)
         weight_tvm = tvm.nd.array(weight_np, dev)
+        bias_tvm = tvm.nd.array(bias_np, dev)
         b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
-        f = tvm.build(s, [data, weight, B], target)
-        f(data_tvm, weight_tvm, b_tvm)
+        f = tvm.build(s, [data, weight, bias, B], target)
+        f(data_tvm, weight_tvm, bias_tvm, b_tvm)
         tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)