Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
8748f11
Blockwise quant RMS norm
ElizaWszola Oct 30, 2025
ea9f4db
Cleanup
ElizaWszola Oct 31, 2025
b3a55fd
Apply quant layer norm fixes from #27865, inv scale fix for int8
ElizaWszola Nov 4, 2025
1e912ee
Cleanup
ElizaWszola Nov 4, 2025
051b451
Vectorize
ElizaWszola Nov 4, 2025
2584f2f
Unify kernel shapes to fuse
ElizaWszola Nov 5, 2025
9bc68d4
Fix
ElizaWszola Nov 6, 2025
0fce111
Cleanup
ElizaWszola Nov 6, 2025
0fac68c
Scalar scale computation is working again
ElizaWszola Nov 6, 2025
294e884
Vectorized
ElizaWszola Nov 6, 2025
54ab82f
Test group_size=64, add benchmarks
ElizaWszola Nov 7, 2025
3718bbc
Merge branch 'main' into blockwise-quant-rms-norm
yewentao256 Nov 7, 2025
e00a6d7
optimize
yewentao256 Nov 7, 2025
77e0078
Add fusion patterns
ElizaWszola Nov 14, 2025
3c61951
Merge branch 'blockwise-quant-rms-norm' of https://github.com/neuralm…
ElizaWszola Nov 14, 2025
0d8c405
Merge branch 'main' into blockwise-quant-rms-norm
ElizaWszola Nov 14, 2025
c63bb1b
Account for transposed scales
ElizaWszola Nov 18, 2025
e8c5563
Cleanup fallback code
ElizaWszola Nov 18, 2025
c745e91
Cleanup comments, var names
ElizaWszola Nov 19, 2025
949db4d
Transpose scales if needed
ElizaWszola Nov 21, 2025
e151ea7
Fix redundant write to scales, write to transposed scales too
ElizaWszola Nov 26, 2025
8364887
Fix build
ElizaWszola Nov 26, 2025
88f524d
Merge branch 'main' into blockwise-quant-rms-norm
ElizaWszola Nov 27, 2025
ee2a354
Keep rms in shared memory in kernels
ElizaWszola Dec 1, 2025
e2b82b1
Constexpr group size
ElizaWszola Dec 1, 2025
bf0d3b5
Fix zero division when group_size is 0
ElizaWszola Dec 1, 2025
5caf3a7
A few more improvements
ElizaWszola Dec 3, 2025
92fd8c9
Cleanup unused template
ElizaWszola Dec 3, 2025
377d204
Merge branch 'main' into blockwise-quant-rms-norm
ElizaWszola Dec 4, 2025
06e3645
Feedback, add gs==64 to fusion tests
ElizaWszola Dec 5, 2025
416f173
Move type dispatch to dispatch function
ElizaWszola Dec 5, 2025
b2e2251
Fix deepgemm when we use e8m0
ElizaWszola Dec 5, 2025
f4a206c
Merge branch 'luka/lora-passing' into blockwise-quant-rms-norm
ProExpertProg Dec 7, 2025
990bc65
Merge branch 'main' into blockwise-quant-rms-norm
ProExpertProg Dec 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 86 additions & 3 deletions benchmarks/fused_kernels/layernorm_rms_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)


@dataclass
Expand All @@ -22,13 +25,15 @@ class bench_params_t:
hidden_size: int
add_residual: bool
dtype: torch.dtype
group_size: list[int]

def description(self):
    """Return a one-line human-readable label for this benchmark config.

    Used as the benchmark sub-label, e.g.
    "N 64 x D 1024 x R True x DT torch.bfloat16 x GS [1, 128]".
    """
    return (
        f"N {self.num_tokens} "
        f"x D {self.hidden_size} "
        f"x R {self.add_residual} "
        # Trailing space fixed: without it the GS field ran directly into
        # the dtype ("...torch.bfloat16x GS [1, 64]").
        f"x DT {self.dtype} "
        f"x GS {self.group_size}"
    )


Expand All @@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]:
HIDDEN_SIZES = list(range(1024, 8129, 1024))
ADD_RESIDUAL = [True, False]
DTYPES = [torch.bfloat16, torch.float]
GROUP_SIZES = [[1, 64], [1, 128]]

combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations)
)
return bench_params

Expand All @@ -52,6 +58,7 @@ def unfused_int8_impl(
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
Expand All @@ -69,6 +76,7 @@ def unfused_fp8_impl(
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
Expand All @@ -81,23 +89,63 @@ def unfused_fp8_impl(
torch_out, _ = ops.scaled_fp8_quant(torch_out)


def unfused_groupwise_fp8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Reference path: RMSNorm followed by a separate per-token-group FP8 quant.

    Mirrors the fused groupwise kernel as two independent ops so the two can
    be timed against each other.
    """
    # Norm: forward_cuda returns a (norm, residual) tuple when a residual
    # tensor is supplied, otherwise a single tensor.
    if residual is not None:
        normed, _ = rms_norm_layer.forward_cuda(x, residual)
    else:
        normed = rms_norm_layer.forward_cuda(x, residual)

    # Quant: the hidden-dim group width is the second entry of group_size.
    normed, _ = per_token_group_quant_fp8(
        normed, group_size=group_size[1], use_ue8m0=False
    )


def fused_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Fused RMSNorm + dynamic per-token quant via a single custom op.

    group_size is accepted only to keep a uniform bench signature; the
    per-token kernel does not use it.
    """
    quantized, _ = ops.rms_norm_dynamic_per_token_quant(
        x,
        rms_norm_layer.weight,
        1e-6,
        quant_dtype,
        residual=residual,
    )


def fused_groupwise_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Fused RMSNorm + groupwise (blockwise) quant via a single custom op."""
    # Scales are requested in transposed layout to match the fused kernel's
    # benchmark configuration.
    quantized, _ = ops.rms_norm_per_block_quant(
        x, rms_norm_layer.weight, 1e-6, quant_dtype, group_size,
        residual=residual, is_scale_transposed=True,
    )


# Bench functions
def bench_fn(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor,
quant_dtype: torch.dtype,
group_size: list[int],
label: str,
sub_label: str,
fn: Callable,
Expand All @@ -110,10 +158,11 @@ def bench_fn(
"x": x,
"residual": residual,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
Expand Down Expand Up @@ -147,6 +196,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.int8,
params.group_size,
label,
sub_label,
unfused_int8_impl,
Expand All @@ -161,6 +211,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
Expand All @@ -175,6 +226,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.int8,
params.group_size,
label,
sub_label,
fused_impl,
Expand All @@ -189,13 +241,44 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
"fused_fp8_impl",
)
)

# unfused groupwise fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)

# fused groupwise fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_groupwise_impl,
"fused_groupwise_fp8_impl",
)
)

print_timers(timers)

return timers
Expand Down
18 changes: 18 additions & 0 deletions csrc/dispatch_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,24 @@
} \
}

#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \
constexpr bool const_expr = true; \
__VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
__VA_ARGS__(); \
}

// Dispatch a runtime group size to a compile-time constant: invokes the
// callable in __VA_ARGS__ with `const_group_size` in scope as a constexpr
// int (128 or 64).
// NOTE(review): any other group_size value silently expands to nothing
// (there is no final else raising an error), so callers must validate
// group_size before dispatching -- confirm at call sites.
#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \
  if (group_size == 128) {                                          \
    constexpr int const_group_size = 128;                           \
    __VA_ARGS__();                                                  \
  } else if (group_size == 64) {                                    \
    constexpr int const_group_size = 64;                            \
    __VA_ARGS__();                                                  \
  }

#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
switch (NUM_DIMS) { \
case 2: { \
Expand Down
7 changes: 7 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);

// Fused RMS norm + per-block (groupwise) quantization.
// Writes the quantized activations into `out` and the per-group scales into
// `scales`; `group_size` selects the quantization group width and
// `is_scale_transposed` selects the layout of the written scales.
// NOTE(review): the exact shape/layout contract of `scales` and the meaning
// of `scale_ub` are defined by the kernel implementation -- confirm there.
void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& weight,
                              torch::Tensor& scales, double const epsilon,
                              std::optional<torch::Tensor> scale_ub,
                              std::optional<torch::Tensor> residual,
                              int64_t group_size, bool is_scale_transposed);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
Expand Down
Loading