Skip to content

Commit faad5ac

Browse files
committed
Vectorize RMS norm variance using vectorize_read_with_alignment
Signed-off-by: Benji Beck <[email protected]>
1 parent 41f3884 commit faad5ac

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

csrc/layernorm_kernels.cu

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#include "dispatch_utils.h"
33
#include "cub_helpers.h"
44
#include "core/batch_invariant.hpp"
5+
#include "quantization/vectorization_utils.cuh"
56

67
#include <torch/cuda.h>
78
#include <c10/cuda/CUDAGuard.h>
@@ -18,11 +19,22 @@ __global__ void rms_norm_kernel(
1819
const float epsilon, const int num_tokens, const int hidden_size) {
1920
__shared__ float s_variance;
2021
float variance = 0.0f;
22+
const scalar_t* input_row = input + blockIdx.x * input_stride;
2123

22-
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
23-
const float x = (float)input[blockIdx.x * input_stride + idx];
24+
constexpr int VEC_SIZE = 8;
25+
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
26+
#pragma unroll
27+
for (int i = 0; i < VEC_SIZE; ++i) {
28+
float x = static_cast<float>(vec.val[i]);
29+
variance += x * x;
30+
}
31+
};
32+
auto scalar_op = [&variance](const scalar_t& val) {
33+
float x = static_cast<float>(val);
2434
variance += x * x;
25-
}
35+
};
36+
vllm::vectorize_read_with_alignment<VEC_SIZE>(
37+
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
2638

2739
using BlockReduce = cub::BlockReduce<float, 1024>;
2840
__shared__ typename BlockReduce::TempStorage reduceStore;

csrc/layernorm_quant_kernels.cu

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
#include "dispatch_utils.h"
1111
#include "cub_helpers.h"
1212
#include "core/batch_invariant.hpp"
13+
#include "quantization/vectorization_utils.cuh"
1314

1415
#include <torch/cuda.h>
1516
#include <c10/cuda/CUDAGuard.h>
@@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel(
2829
__shared__ float s_variance;
2930
float variance = 0.0f;
3031

31-
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
32-
const float x = (float)input[blockIdx.x * input_stride + idx];
32+
const scalar_t* input_row = input + blockIdx.x * input_stride;
33+
34+
constexpr int VEC_SIZE = 8;
35+
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
36+
#pragma unroll
37+
for (int i = 0; i < VEC_SIZE; ++i) {
38+
float x = static_cast<float>(vec.val[i]);
39+
variance += x * x;
40+
}
41+
};
42+
auto scalar_op = [&variance](const scalar_t& val) {
43+
float x = static_cast<float>(val);
3344
variance += x * x;
34-
}
45+
};
46+
vllm::vectorize_read_with_alignment<VEC_SIZE>(
47+
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
3548

3649
using BlockReduce = cub::BlockReduce<float, 1024>;
3750
__shared__ typename BlockReduce::TempStorage reduceStore;

0 commit comments

Comments (0)