diff --git a/examples/norm/rms_norm.py b/examples/norm/rms_norm.py index 25bac50fc..40d367c2d 100644 --- a/examples/norm/rms_norm.py +++ b/examples/norm/rms_norm.py @@ -21,7 +21,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): A_local[i, j] += A_shared[i, j] * A_shared[i, j] T.reduce_sum(A_local, A_powsum, dim=1) for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12 + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for k in range(num_k_step): # reverse, better cache hit rate @@ -51,7 +51,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): A_pow_local[i, j] = A_local[i, j] * A_local[i, j] T.reduce_sum(A_pow_local, A_powsum, dim=1) for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12 + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for i, j in T.Parallel(blk_m, N): A_local[i, j] *= A_powsum[i] T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :]) diff --git a/examples/norm/test_rms_norm.py b/examples/norm/test_rms_norm.py index 8cc413531..a05f9b082 100644 --- a/examples/norm/test_rms_norm.py +++ b/examples/norm/test_rms_norm.py @@ -22,7 +22,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): A_local[i, j] += A_shared[i, j] * A_shared[i, j] T.reduce_sum(A_local, A_powsum, dim=1) for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12 + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for k in range(num_k_step): # reverse, better cache hit rate @@ -51,7 +51,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): A_pow_local[i, j] = A_local[i, j] * A_local[i, j] T.reduce_sum(A_pow_local, A_powsum, dim=1) for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12 + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) for i, j in T.Parallel(blk_m, N): A_local[i, j] *= A_powsum[i] T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])