diff --git a/.github/workflows/pr-test-xeon.yml b/.github/workflows/pr-test-xeon.yml
index 0fb4721ba173..abaa9401d96d 100644
--- a/.github/workflows/pr-test-xeon.yml
+++ b/.github/workflows/pr-test-xeon.yml
@@ -115,7 +115,7 @@ jobs:
timeout-minutes: 36
run: |
docker exec -w /sglang-checkout/ ci_sglang_xeon \
- bash -c "source /opt/.venv/bin/activate && cd ./test/srt && python3 run_suite.py --suite per-commit-cpu --timeout-per-file 1500"
+ bash -c "source /opt/.venv/bin/activate && cd ./test && python3 run_suite.py --hw cpu --suite stage-b-test-cpu"
- name: Change permission
timeout-minutes: 2
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8118e91c26cf..7fea2c91d535 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -92,6 +92,7 @@ repos:
entry: python3 scripts/ci/check_registered_tests.py
language: system
files: ^test/registered/.*\.py$
+ exclude: ^test/registered/.*/utils\.py$
pass_filenames: false
- id: check-no-docs-changes
name: reject changes under legacy docs/
diff --git a/docs_new/index.mdx b/docs_new/index.mdx
index 6a5b1ed19ddb..908d517386da 100644
--- a/docs_new/index.mdx
+++ b/docs_new/index.mdx
@@ -83,7 +83,7 @@ It is designed to deliver low-latency and high-throughput inference across a wid
}}
>
- {"DeepSeek-V4 on Day 0: From Fast Inference to Verified RL with SGLang and Miles"}
+ {"Updating 1T parameters in seconds \u2014 P2P weight transfer in Large Scale Distributed RL"}
- {"April 25, 2026"} + {"April 29, 2026"}
- {"HiSparse: Turbocharging Sparse Attention with Hierarchical Memory"}
+ {"DeepSeek-V4 on Day 0: From Fast Inference to Verified RL with SGLang and Miles"}
- {"April 10, 2026"} + {"April 25, 2026"}
- {"Highlights of SGLang at NVIDIA GTC 2026"}
+ {"HiSparse: Turbocharging Sparse Attention with Hierarchical Memory"}
- {"March 31, 2026"} + {"April 10, 2026"}
- {"Elastic EP in SGLang: Achieving Partial Failure Tolerance for DeepSeek MoE Deployments"}
+ {"Highlights of SGLang at NVIDIA GTC 2026"}
- {"March 25, 2026"} + {"March 31, 2026"}
- {"ROCm Support for Miles: Large-Scale RL Post-Training on AMD Instinct\u2122 GPUs"}
+ {"Elastic EP in SGLang: Achieving Partial Failure Tolerance for DeepSeek MoE Deployments"}
- {"March 17, 2026"} + {"March 25, 2026"}
- {"March 11, 2026"} + {"March 17, 2026"}
diff --git a/scripts/ci/check_registered_tests.py b/scripts/ci/check_registered_tests.py index 3a9e9b87b242..66def25ff4d2 100755 --- a/scripts/ci/check_registered_tests.py +++ b/scripts/ci/check_registered_tests.py @@ -22,11 +22,11 @@ def main() -> int: ci_register = importlib.util.module_from_spec(spec) spec.loader.exec_module(ci_register) - # Same filter as run_suite.py: skip conftest.py and __init__.py + # Same filter as run_suite.py: skip conftest.py, __init__.py, and utils.py files = sorted( f for f in glob.glob("test/registered/**/*.py", recursive=True) - if os.path.basename(f) not in ("conftest.py", "__init__.py") + if os.path.basename(f) not in ("conftest.py", "__init__.py", "utils.py") ) if not files: return 0 diff --git a/test/registered/cpu/test_activation.py b/test/registered/cpu/test_activation.py new file mode 100644 index 000000000000..fe020c8872f4 --- /dev/null +++ b/test/registered/cpu/test_activation.py @@ -0,0 +1,59 @@ +import itertools +import unittest + +import torch +from utils import GeluAndMul, SiluAndMul, precision + +from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class TestActivation(CustomTestCase): + M = [128, 129, 257] + N = [22016, 22018] + dtype = [torch.float16, torch.bfloat16] + + def _silu_and_mul_test(self, m, n, dtype): + set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) + + x = torch.randn([m, n], dtype=dtype) + + out = torch.ops.sgl_kernel.silu_and_mul_cpu(x) + ref_out = SiluAndMul(x) + + atol = rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + def _gelu_and_mul_test(self, m, n, dtype): + x = torch.randn([m, n], dtype=dtype) + + out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x) + ref_out = GeluAndMul(x, approximate="none") + + atol 
= rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + def _gelu_tanh_and_mul_test(self, m, n, dtype): + x = torch.randn([m, n], dtype=dtype) + + out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x) + ref_out = GeluAndMul(x, approximate="tanh") + + atol = rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + def test_activation(self): + for params in itertools.product(self.M, self.N, self.dtype): + with self.subTest(m=params[0], n=params[1], dtype=params[2]): + self._silu_and_mul_test(*params) + self._gelu_and_mul_test(*params) + self._gelu_tanh_and_mul_test(*params) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_binding.py b/test/registered/cpu/test_binding.py new file mode 100644 index 000000000000..f623045a74df --- /dev/null +++ b/test/registered/cpu/test_binding.py @@ -0,0 +1,30 @@ +import re +import unittest + +import torch + +kernel = torch.ops.sgl_kernel + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + + +class TestBinding(CustomTestCase): + def test_binding(self): + start_id = 1 + n_cpu = 6 + + expected_cores = list(map(str, range(start_id, start_id + n_cpu))) + cpu_ids = ",".join(expected_cores) + output = kernel.init_cpu_threads_env(cpu_ids) + + bindings = re.findall(r"OMP tid: \d+, core (\d+)", output) + self.assertEqual(len(bindings), n_cpu) + + self.assertEqual(bindings, expected_cores) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_bmm.py b/test/registered/cpu/test_bmm.py new file mode 100644 index 000000000000..93257b7182e7 --- /dev/null +++ b/test/registered/cpu/test_bmm.py @@ -0,0 +1,98 @@ +import itertools +import unittest + +# TODO: use interface in cpu.py +import torch +import torch.nn as nn +from utils import precision + +from 
sglang.srt.layers.quantization.fp8_utils import input_to_float8 +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class Mod(nn.Module): + def __init__(self, input_channel, output_channel, has_bias): + super(Mod, self).__init__() + self.linear = torch.nn.Linear(input_channel, output_channel, has_bias) + + def forward(self, x): + return self.linear(x) + + +class TestBmm(CustomTestCase): + M = [1, 2, 11, 111] + N = [128 + 32, 512] + K = [512 + 32, 128 + 32] + B = [1, 16, 17] + chunk = [True, False] + + def _get_bmm_inputs(self, B, M, N, K, chunk, dtype): + if chunk: + mat1 = ( + torch.randn(M, B, K + 64, dtype=dtype).narrow(2, 0, K).transpose_(0, 1) + ) + mat2 = torch.randn(B, N, K, dtype=dtype).transpose_(1, 2) + mat3 = ( + torch.randn(M, B, N + 64, dtype=dtype).narrow(2, 0, N).transpose_(0, 1) + ) + else: + mat1 = torch.randn(M, B, K, dtype=dtype).transpose_(0, 1) + mat2 = torch.randn(B, N, K, dtype=dtype).transpose_(1, 2) + mat3 = torch.randn(M, B, N, dtype=dtype).transpose_(0, 1) + return mat1, mat2, mat3 + + def _bf16_bmm(self, B, M, N, K, chunk, dtype=torch.bfloat16): + mat1, mat2, mat3 = self._get_bmm_inputs(B, M, N, K, chunk, dtype) + ref = torch.bmm(mat1, mat2) + mat2_t = mat2.transpose_(1, 2) + mat3.zero_() + torch.ops.sgl_kernel.bmm_cpu(mat3, mat1, mat2, False, None) + atol = rtol = precision[ref.dtype] + torch.testing.assert_close(ref, mat3, atol=atol, rtol=rtol) + + packed_B = torch.ops.sgl_kernel.convert_weight_packed(mat2_t) + mat3.zero_() + torch.ops.sgl_kernel.bmm_cpu(mat3, mat1, packed_B, True, None) + torch.testing.assert_close(ref, mat3, atol=atol, rtol=rtol) + + def _fp8_bmm(self, B, M, N, K, chunk, dtype=torch.bfloat16): + mat1, mat2, mat3 = self._get_bmm_inputs(B, M, N, K, chunk, dtype) + mat2_q, mat2_s = input_to_float8(mat2) + ref = torch.bmm(mat1, mat2_q.to(torch.bfloat16)) * mat2_s + mat2_q_t = 
mat2_q.transpose_(1, 2).contiguous() + mat3.zero_() + atol = rtol = precision[ref.dtype] + torch.ops.sgl_kernel.bmm_cpu(mat3, mat1, mat2_q_t, False, mat2_s) + torch.testing.assert_close(ref, mat3, atol=atol, rtol=rtol) + + packed_B_q = torch.ops.sgl_kernel.convert_weight_packed(mat2_q_t) + mat3.zero_() + torch.ops.sgl_kernel.bmm_cpu(mat3, mat1, packed_B_q, True, mat2_s) + torch.testing.assert_close(ref, mat3, atol=atol, rtol=rtol) + + def test_bmm(self): + for params in itertools.product( + self.B, + self.M, + self.N, + self.K, + self.chunk, + ): + with self.subTest( + B=params[0], + M=params[1], + N=params[2], + K=params[3], + chunk=params[4], + ): + self._bf16_bmm(*params) + self._fp8_bmm(*params) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_causal_conv1d.py b/test/registered/cpu/test_causal_conv1d.py new file mode 100644 index 000000000000..f588eafa2bdc --- /dev/null +++ b/test/registered/cpu/test_causal_conv1d.py @@ -0,0 +1,330 @@ +import unittest +from typing import Optional + +import sgl_kernel # noqa: F401 +import torch +import torch.nn.functional as F +from utils import parametrize, precision + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +causal_conv1d_weight_pack = torch.ops.sgl_kernel.causal_conv1d_weight_pack +causal_conv1d_fwd = torch.ops.sgl_kernel.causal_conv1d_fwd_cpu +causal_conv1d_update = torch.ops.sgl_kernel.causal_conv1d_update_cpu + + +torch.manual_seed(1234) + +PAD_SLOT_ID = -1 + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + 
final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update_ref( + x, conv_state, weight, bias=None, activation=None, cache_seqlens=None +): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the + conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. 
+ + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + + x_new = torch.cat([conv_state, x], dim=-1) + conv_state.copy_(x_new[:, :, -state_len:]) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[ + :, :, -seqlen: + ] + + out = out.squeeze(-1) + return out if activation is None else F.silu(out) + + +class TestCausalConv1d(CustomTestCase): + activation = "silu" + + @parametrize( + batch=[1, 1024], + dim=[96, 512], + seqlen=[2, 36], + width=[4], + has_bias=[True, False], + has_initial_state=[True, False], + ) + def test_causal_conv1d( + self, + batch, + dim, + seqlen, + width, + has_bias, + has_initial_state, + dtype=torch.bfloat16, + prepack=True, + ): + x = torch.randn(batch, seqlen, dim).to(dtype).transpose_(-1, -2) + weight = torch.randn(dim, width).to(dtype) + bias = torch.randn(dim).to(dtype) if has_bias else None + + if has_initial_state: + initial_states = torch.randn(batch, dim, width - 1, dtype=dtype) + has_initial_state_tensor = torch.ones(batch, dtype=torch.bool) + else: + initial_states = None + has_initial_state_tensor = None + + packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight + + out_ref, final_states_ref = causal_conv1d_ref( + x, + weight, + bias, + initial_states, + return_final_states=has_initial_state, + activation=self.activation, + ) + + out = causal_conv1d_fwd( + x, + packed_weight, + bias, + initial_states, + None, + None, + has_initial_state_tensor, + self.activation in ["silu"], + PAD_SLOT_ID, + prepack, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) + torch.testing.assert_close( + final_states_ref, initial_states, atol=atol, rtol=rtol + ) + + @parametrize( + batch=[11], + dim=[96], + max_seqlen=[66], + 
width=[4], + ) + def test_causal_conv1d_varlen( + self, + batch, + dim, + max_seqlen, + width, + has_bias=False, + dtype=torch.bfloat16, + prepack=False, + ): + total_entries = batch + 3 + + seqlens = torch.randint(1, max_seqlen, (batch + 1,)) + seqlens[0] = 0 + # 1 or 2 must test + seqlens[-2] = 2 + + query_start_loc = torch.cumsum(seqlens, dim=0).to(torch.int32) + + seqlen = query_start_loc[-1].item() + x = torch.randn(seqlen, dim, dtype=dtype).transpose_(-1, -2) + weight = torch.randn(dim, width, dtype=dtype) + bias = torch.randn(dim, dtype=dtype) if has_bias else None + + final_states = torch.randn(total_entries, dim, width - 1, dtype=dtype) + final_states_ref = final_states.clone() + + has_initial_states = torch.randint(0, 2, (batch,), dtype=torch.bool).fill_( + False + ) + state_indices = torch.randperm(total_entries, dtype=torch.int32)[:batch] + + out_ref = [] + out_ref_b = [] + + return_final_states = final_states is not None + splits = torch.split(x, seqlens[1:].tolist(), dim=1) + for i, x_s in enumerate(splits): + out_ref_b.append( + causal_conv1d_ref( + x_s.unsqueeze(0), + weight, + bias, + activation=self.activation, + return_final_states=return_final_states, + final_states_out=( + final_states_ref[state_indices[i]].unsqueeze(0) + if return_final_states + else None + ), + initial_states=( + final_states_ref[state_indices[i]].unsqueeze(0) + if has_initial_states[i] + else None + ), + ) + ) + out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) + out_ref_tensor = torch.cat(out_ref, dim=0).squeeze(0) + + out = causal_conv1d_fwd( + x, + weight, + bias, + final_states, + query_start_loc, + state_indices, + has_initial_states, + self.activation in ["silu"], + PAD_SLOT_ID, + prepack, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref_tensor, out, atol=atol, rtol=rtol) + torch.testing.assert_close(final_states_ref, final_states, atol=atol, rtol=rtol) + + @parametrize( + batch=[11], + dim=[32, 64, 96], + width=[4], + ) + def 
test_causal_conv1d_update( + self, batch, dim, width, has_bias=False, dtype=torch.bfloat16, prepack=True + ): + x = torch.randn(batch, dim).to(dtype) + conv_state = torch.randn(batch, dim, width - 1, dtype=dtype) + weight = torch.randn(dim, width).to(dtype) + bias = torch.randn(dim).to(dtype) if has_bias else None + + packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight + + conv_state_ref = conv_state.clone() + out_ref = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias, activation=self.activation + ) + + cache_seqlens = None + conv_state_indices = None + out = causal_conv1d_update( + x, + conv_state, + packed_weight, + bias, + self.activation in ["silu"], + cache_seqlens, + conv_state_indices, + PAD_SLOT_ID, + prepack, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) + torch.testing.assert_close(conv_state_ref, conv_state, atol=atol, rtol=rtol) + + @parametrize( + batch=[7], + dim=[96], + width=[4], + ) + def test_causal_conv1d_update_with_batch_gather( + self, batch, dim, width, has_bias=False, dtype=torch.bfloat16, prepack=True + ): + total_entries = batch + 3 + + x = torch.randn(batch, dim).to(dtype=dtype) + + conv_state_indices = torch.randperm(total_entries)[:batch].to(dtype=torch.int32) + conv_state = torch.randn(total_entries, dim, width - 1, dtype=dtype) + + weight = torch.randn(dim, width).to(dtype=dtype) + bias = torch.randn(dim).to(dtype=dtype) if has_bias else None + conv_state_ref = conv_state[conv_state_indices, :] + + packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight + + out_ref = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias, activation=self.activation + ) + + cache_seqlens = None + out = causal_conv1d_update( + x, + conv_state, + packed_weight, + bias, + self.activation in ["silu"], + cache_seqlens, + conv_state_indices, + PAD_SLOT_ID, + prepack, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, 
atol=atol, rtol=rtol) + torch.testing.assert_close( + conv_state_ref, conv_state[conv_state_indices, :], atol=atol, rtol=rtol + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_cpu_graph.py b/test/registered/cpu/test_cpu_graph.py new file mode 100644 index 000000000000..37d70df72b24 --- /dev/null +++ b/test/registered/cpu/test_cpu_graph.py @@ -0,0 +1,91 @@ +""" +Usage: +python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu_torch_compile_cpu +""" + +import copy +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + intel_amx_benchmark, + is_in_ci, + popen_launch_server, +) + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + + +class TestCPUGraph(CustomTestCase): + + @intel_amx_benchmark( + extra_args=[ + "--batch-size", + "1", + "--mem-fraction-static", + "0.05", + "--enable-torch-compile", + "--torch-compile-max-bs", + "2", + "--cuda-graph-bs", + "2", + ], + min_throughput=7, + ) + def test_latency_torch_compile_cpu(self): + return DEFAULT_MLA_MODEL_NAME_FOR_TEST + + def test_mmlu_torch_compile_cpu(self): + model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + cpu_ids_by_node = get_cpu_ids_by_node() + n_numa_node = len(cpu_ids_by_node) + env = copy.deepcopy(os.environ) + env["SGLANG_CPU_OMP_THREADS_BIND"] = "all" + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", + "intel_amx", + "--mem-fraction-static", + "0.05", + "--disable-radix", + "--trust-remote-code", + "--disable-overlap-schedule", + "--enable-torch-compile", + "--cuda-graph-bs", + "2", + "--tp", + f"{n_numa_node}", + ], 
+ env=env, + ) + + try: + args = SimpleNamespace( + base_url=base_url, + model=model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + if is_in_ci(): + self.assertGreater(metrics["score"], 0.45) + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_decode.py b/test/registered/cpu/test_decode.py new file mode 100644 index 000000000000..316c446bdab7 --- /dev/null +++ b/test/registered/cpu/test_decode.py @@ -0,0 +1,172 @@ +import unittest + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class TestDecodeAttention(CustomTestCase): + def _run_sdpa_forward_decode( + self, + query: torch.Tensor, + output: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + scaling=None, + enable_gqa=False, + causal=False, + ): + # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size] + query = query.movedim(0, query.dim() - 2) + + start_q, start_kv = 0, 0 + for seq_idx in range(seq_lens.shape[0]): + seq_len_q = 1 + seq_len_kv = seq_lens[seq_idx] + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + + per_req_query = query[:, start_q:end_q, :] + + # get key and value from cache. per_req_tokens contains the kv cache + # index for each token in the sequence. 
+ req_pool_idx = req_pool_indices[seq_idx] + per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] + per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) + per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) + + per_req_out = ( + scaled_dot_product_attention( + per_req_query.unsqueeze(0), + per_req_key.unsqueeze(0), + per_req_value.unsqueeze(0), + enable_gqa=enable_gqa, + scale=scaling, + is_causal=causal, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) + output[start_q:end_q, :, :] = per_req_out + start_q, start_kv = end_q, end_kv + + return output + + def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, dtype, device): + # This represents the number of tokens already in the sequence + seq_len = 1024 + total_tokens = B * seq_len + sm_scale = 1.0 / (D**0.5) + logit_cap = 0.0 + num_kv_splits = 8 + enable_gqa = H_Q != H_KV + + # q represents the new token being generated, one per batch + q = torch.randn(B, H_Q, D, dtype=dtype, device=device) + + # k_buffer and v_buffer represent all previous tokens + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device) + v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device) + + key = torch.randn(B, H_KV, D, dtype=dtype) + value = torch.randn(B, H_KV, D_V, dtype=dtype) + loc = torch.randint(0, 10, (B,)).to(torch.int64) + + # set kv cache + k_buffer[loc] = key + v_buffer[loc] = value + + # o will have the same shape as q + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device) + o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device) + + req_to_token = ( + torch.arange(total_tokens, device=device) + .reshape(B, seq_len) + .to(torch.int32) + ) + b_req_idx = torch.arange(B, device=device).to(torch.int64) + b_seq_len = torch.full((B,), seq_len, device=device).to(torch.int64) + + attn_logits = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, + device=device, + ) + + # k_buffer, v_buffer, query, key and value 
supports non-contiguous tensors + k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1) + v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1) + q = q.transpose(0, 1).contiguous().transpose(0, 1) + key = key.transpose(0, 1).contiguous().transpose(0, 1) + value = value.transpose(0, 1).contiguous().transpose(0, 1) + torch.ops.sgl_kernel.decode_attention_cpu( + q, + k_buffer, + v_buffer, + o, + key, + value, + loc, + attn_logits, + req_to_token, + b_req_idx, + b_seq_len, + sm_scale, + logit_cap, + ) + + self._run_sdpa_forward_decode( + q, + o_grouped, + k_buffer, + v_buffer, + req_to_token, + b_req_idx, + b_seq_len, + scaling=sm_scale, + enable_gqa=enable_gqa, + ) + + cos_sim = torch.nn.functional.cosine_similarity( + o.flatten(), o_grouped.flatten(), dim=0 + ) + self.assertGreater(cos_sim.item(), 0.99) + torch.testing.assert_close(o, o_grouped, atol=3e-2, rtol=1e-6) + + def _test_grouped_decode_attention(self, device="cpu"): + configs = [ + (2, 16, 16, 64, 64), + (2, 16, 1, 16, 16), + (2, 32, 8, 33, 55), + (2, 16, 1, 64, 64), + (2, 64, 1, 13, 13), + (2, 128, 1, 80, 80), + (2, 128, 2, 512, 512), + (1, 16, 1, 576, 512), + (1, 16, 16, 576, 512), + (1, 22, 1, 576, 512), + (1, 40, 8, 128, 128), + ] + + for B, H_Q, H_KV, D, D_V in configs: + for dtype in [torch.bfloat16, torch.float16]: + self._test_grouped_decode_attention_once( + B, H_Q, H_KV, D, D_V, dtype=dtype, device=device + ) + + def test_grouped_decode_attention(self): + self._test_grouped_decode_attention("cpu") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_extend.py b/test/registered/cpu/test_extend.py new file mode 100644 index 000000000000..573531356761 --- /dev/null +++ b/test/registered/cpu/test_extend.py @@ -0,0 +1,225 @@ +import unittest + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + 
+register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class TestExtendAttention(CustomTestCase): + + def _run_sdpa_forward_extend( + self, + query: torch.Tensor, + output: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_prefix_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + scaling=None, + enable_gqa=False, + causal=False, + ): + + assert seq_lens.shape[0] == extend_prefix_lens.shape[0] + assert seq_lens.shape[0] == extend_seq_lens.shape[0] + + # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size] + query = query.movedim(0, query.dim() - 2) + + start_q, start_kv = 0, 0 + for seq_idx in range(seq_lens.shape[0]): + + extend_seq_len_q = extend_seq_lens[seq_idx] + prefill_seq_len_q = extend_prefix_lens[seq_idx] + + seq_len_kv = seq_lens[seq_idx] + end_q = start_q + extend_seq_len_q + end_kv = start_kv + seq_len_kv + + per_req_query = query[:, start_q:end_q, :] + per_req_query_redudant = torch.empty( + (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]), + dtype=per_req_query.dtype, + device=per_req_query.device, + ) + + per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query + + # get key and value from cache. per_req_tokens contains the kv cache + # index for each token in the sequence. 
+ req_pool_idx = req_pool_indices[seq_idx] + per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] + per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) + per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) + + per_req_out_redudant = ( + scaled_dot_product_attention( + per_req_query_redudant.unsqueeze(0), + per_req_key.unsqueeze(0), + per_req_value.unsqueeze(0), + enable_gqa=enable_gqa, + scale=scaling, + is_causal=causal, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) + output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :] + start_q, start_kv = end_q, end_kv + return output + + def _test_extend_attention_once( + self, + B, + N_CTX, + H_Q, + H_KV, + D, + DV, + mla=False, + *, + b_seq_len_prefix=None, + b_seq_len_extend=None, + ): + dtype = torch.bfloat16 + + if b_seq_len_prefix is None: + b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32) + if mla: + b_seq_len_prefix.zero_() + else: + b_seq_len_prefix = torch.as_tensor(b_seq_len_prefix, dtype=torch.int32) + + if b_seq_len_extend is None: + b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32) + else: + b_seq_len_extend = torch.as_tensor(b_seq_len_extend, dtype=torch.int32) + + b_seq_len = b_seq_len_prefix + b_seq_len_extend + max_len_in_batch = torch.max(b_seq_len, 0)[0].item() + + b_req_idx = torch.arange(B, dtype=torch.int32) + req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32) + b_start_loc = torch.zeros((B,), dtype=torch.int32) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + for i in range(B): + req_to_tokens[i, : b_seq_len[i]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len[i] + ) + + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + + H_BUF = 1 if mla else H_KV + k_buffer = 
torch.randn((total_token_num, H_BUF, D), dtype=dtype) + v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype) + + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype) + v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype) + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype) + + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] + extend_start = b_start_loc_extend[i] + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] + k_extend[extend_start:extend_end] = k_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + v_extend[extend_start:extend_end] = v_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = ( + torch.randn((b_seq_len_extend[i], H_Q, D), dtype=dtype) * 20 + ) + + # q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors + q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1) + k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1) + v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1) + k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1) + v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1) + + b_seq_len_extend = b_seq_len - b_seq_len_prefix + b_start_loc_extend = torch.zeros_like(b_seq_len) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + + sm_scale = 1.0 / (D**0.5) + logit_cap = 0.0 + + # handle index type + b_req_idx = b_req_idx.to(torch.int64) + b_seq_len = b_seq_len.to(torch.int64) + + enable_gqa = H_Q != H_KV + o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype) + self._run_sdpa_forward_extend( + q_extend, + o_ref, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_seq_len_prefix, + b_seq_len_extend, + scaling=sm_scale, + enable_gqa=enable_gqa, + causal=True, + ) + + 
o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype) + torch.ops.sgl_kernel.extend_attention_cpu( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_seq_len_extend, + b_start_loc_extend, + max_len_extend, + sm_scale, + logit_cap, + ) + + torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2) + + def test_extend_attention(self): + for is_mla in [True, False]: + self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla) + self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla) + self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla) + self._test_extend_attention_once(1, 9000, 16, 1, 32, 32, is_mla) + + def test_extend_attention_large_seq_causal_mask(self): + self._test_extend_attention_once( + B=1, + N_CTX=5001, + H_Q=8, + H_KV=2, + D=64, + DV=64, + b_seq_len_prefix=[0], + b_seq_len_extend=[5000], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_flash_attn.py b/test/registered/cpu/test_flash_attn.py new file mode 100644 index 000000000000..f7f47c454750 --- /dev/null +++ b/test/registered/cpu/test_flash_attn.py @@ -0,0 +1,245 @@ +import unittest + +import sgl_kernel # noqa: F401 +import torch +import torch.nn.functional as F +from utils import parametrize, precision + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +flash_attn_varlen_func = torch.ops.sgl_kernel.flash_attn_varlen_func + +torch.manual_seed(1234) + + +def flash_attn_varlen_ref( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + is_causal, + enable_gqa, +): + cu_q = cu_seqlens_q.tolist() + cu_k = cu_seqlens_k.tolist() + batch = len(cu_k) - 1 + + # [T, H, D] -> [1, H, T, D] + q, k, v = [x.unsqueeze(0).transpose(1, 2) for x in [q, k, v]] + + B, H, T, D = q.shape + out = torch.empty(B, H, T, v.size(-1), dtype=q.dtype) + for b in 
range(batch): + start_q, end_q = cu_q[b], cu_q[b + 1] + start_k, end_k = cu_k[b], cu_k[b + 1] + + out[:, :, start_q:end_q, :] = F.scaled_dot_product_attention( + q[:, :, start_q:end_q, :], + k[:, :, start_k:end_k, :], + v[:, :, start_k:end_k, :], + is_causal=is_causal, + enable_gqa=enable_gqa, + ) + + # [1, H, T, D] -> [T, H, D] + return out.transpose(1, 2).squeeze(0) + + +# faster version ref kernel for non varlen case +def flash_attn_non_varlen_ref( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + is_causal, + enable_gqa, +): + cu_q = cu_seqlens_q.tolist() + cu_k = cu_seqlens_k.tolist() + batch = len(cu_k) - 1 + + B_T, H, D = q.shape + T = B_T // batch + + # [T, H, D] -> [1, H, T, D] + q, k, v = [x.reshape(batch, T, H, D).transpose(1, 2) for x in [q, k, v]] + + out = F.scaled_dot_product_attention( + q, + k, + v, + is_causal=is_causal, + enable_gqa=enable_gqa, + ) + # [B, H, T, D] -> [B * T, H, D] + return out.transpose(1, 2).reshape(batch * T, H, D) + + +class TestFlashAttn(CustomTestCase): + + @parametrize( + batch=[4], + max_seqlen_q=[35, 96], + max_seqlen_k=[35, 96], + num_heads=[16], + num_heads_kv=[16, 2], + head_dim=[32, 48], # test when D is not 32x + head_dim_v=[32], + is_causal=[True, False], + ) + def test_flash_attn_varlen( + self, + batch, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_kv, + head_dim, + head_dim_v, + is_causal, + ): + dtype = torch.bfloat16 + + # random seqlens for k and kv + seqlens_q = torch.randint(1, max_seqlen_q, (batch,), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlen_k, (batch,), dtype=torch.int32) + cu_seqlens_q = torch.zeros((batch + 1,), dtype=torch.int32) + cu_seqlens_k = torch.zeros((batch + 1,), dtype=torch.int32) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, 0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, 0) + + sum_seqlen_q = seqlens_q.sum().item() + sum_seqlen_k = seqlens_k.sum().item() + q = torch.randn(sum_seqlen_q, num_heads, head_dim).to(dtype) + k = torch.randn(sum_seqlen_k, num_heads_kv, 
head_dim).to(dtype) + v = torch.randn(sum_seqlen_k, num_heads_kv, head_dim_v).to(dtype) + + out_ref = flash_attn_varlen_ref( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + is_causal=is_causal, + enable_gqa=num_heads != num_heads_kv, + ) + + out = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqlens_q.max().item(), + seqlens_k.max().item(), + is_causal, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) + + # test with large size to capture overflow issue + @parametrize( + batch=[4097], + max_seqlen_q=[4097], + max_seqlen_k=[4097], + num_heads=[4], + num_heads_kv=[4], + head_dim=[32], + head_dim_v=[32], + is_causal=[False], + ) + def test_flash_attn_large_size( + self, + batch, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_kv, + head_dim, + head_dim_v, + is_causal, + ): + dtype = torch.bfloat16 + + # test the non varlen case + seqlens_q = torch.full((batch,), max_seqlen_q, dtype=torch.int32) + seqlens_k = torch.full((batch,), max_seqlen_k, dtype=torch.int32) + + cu_seqlens_q = torch.zeros((batch + 1,), dtype=torch.int32) + cu_seqlens_k = torch.zeros((batch + 1,), dtype=torch.int32) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, 0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, 0) + + sum_seqlen_q = seqlens_q.sum().item() + sum_seqlen_k = seqlens_k.sum().item() + q = torch.randn(sum_seqlen_q, num_heads, head_dim).to(dtype) + k = torch.randn(sum_seqlen_k, num_heads_kv, head_dim).to(dtype) + v = torch.randn(sum_seqlen_k, num_heads_kv, head_dim_v).to(dtype) + + out_ref = flash_attn_non_varlen_ref( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + is_causal=is_causal, + enable_gqa=num_heads != num_heads_kv, + ) + + out = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqlens_q.max().item(), + seqlens_k.max().item(), + is_causal, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) + + def 
_test_flash_attn_large_seq_causal_mask_once(self, seqlens): + dtype = torch.bfloat16 + num_heads = 8 + num_heads_kv = 2 + head_dim = 64 + + seqlens_t = torch.tensor(seqlens, dtype=torch.int32) + cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens_t, 0) + total = cu_seqlens[-1].item() + max_seqlen = seqlens_t.max().item() + + q = torch.randn(total, num_heads, head_dim, dtype=dtype) + k = torch.randn(total, num_heads_kv, head_dim, dtype=dtype) + v = torch.randn(total, num_heads_kv, head_dim, dtype=dtype) + + out_ref = flash_attn_varlen_ref( + q, k, v, cu_seqlens, cu_seqlens, is_causal=True, enable_gqa=True + ) + out = flash_attn_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, True + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) + + def test_flash_attn_large_seq_causal_mask(self): + # Non-varlen path: single sequence, has_varlen_sequences returns False + # → dispatches to flash_attn_kernel_impl. 
+ self._test_flash_attn_large_seq_causal_mask_once([5000]) + # Varlen path: sequences with different lengths, has_varlen_sequences + # returns True → dispatches to flash_attn_varlen_kernel_impl + self._test_flash_attn_large_seq_causal_mask_once([5000, 4999]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_gemm.py b/test/registered/cpu/test_gemm.py new file mode 100644 index 000000000000..c5ea056e7b8a --- /dev/null +++ b/test/registered/cpu/test_gemm.py @@ -0,0 +1,335 @@ +import itertools +import unittest + +# TODO: use interface in cpu.py +import torch +import torch.nn as nn +from utils import ( + convert_weight, + native_w8a8_per_token_matmul, + per_token_quant_int8, + precision, + unpack_and_dequant_awq, + unpack_and_dequant_gptq, +) + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class Mod(nn.Module): + def __init__(self, input_channel, output_channel, has_bias): + super(Mod, self).__init__() + self.linear = torch.nn.Linear(input_channel, output_channel, has_bias) + + def forward(self, x): + return self.linear(x) + + +class TestGemm(CustomTestCase): + M = [1, 101] + N = [16, 32 * 13] + K = [32 * 16] + has_bias = [False, True] + + M_int8 = [2, 128] + N_int8 = [32 * 12] + K_int8 = [32 * 17] + + M_fp8 = [1, 11] + N_fp8 = [128, 224] + K_fp8 = [512, 576] + + M_awq = [1, 32] + N_awq = [4096] + K_awq = [4096] + + M_gptq = [1, 32] + N_gptq = [4096] + K_gptq = [4096] + + def _bf16_gemm(self, M, N, K, has_bias): + + mat1 = torch.randn(M, K, dtype=torch.bfloat16) + mat2 = torch.randn(N, K, dtype=torch.bfloat16) + + ref = torch.matmul(mat1.float(), mat2.float().t()) + if has_bias: + bias = torch.randn(N, dtype=torch.float32) + ref.add_(bias.bfloat16()) + + ref = ref.bfloat16() + + out = torch.ops.sgl_kernel.weight_packed_linear( + mat1, mat2, bias if has_bias else None, False + ) + + 
packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2) + out2 = torch.ops.sgl_kernel.weight_packed_linear( + mat1, packed_mat2, bias if has_bias else None, True + ) + + atol = rtol = precision[ref.dtype] + torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) + torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol) + + def test_bf16_gemm(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.has_bias, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + has_bias=params[3], + ): + self._bf16_gemm(*params) + + def _bf16_gemm_with_small_oc(self, M, N, K, has_bias, use_post_sigmul): + use_post_sigmul = use_post_sigmul and N == 1 + mat_mul = ( + None if not use_post_sigmul else torch.randn(M, 2 * K, dtype=torch.bfloat16) + ) + mat1 = torch.randn(M, K, dtype=torch.bfloat16) + mat2 = torch.randn(N, K, dtype=torch.bfloat16) + + ref = torch.nn.functional.linear(mat1, mat2) + if has_bias: + bias = torch.randn(N, dtype=torch.float32) + ref.add_(bias) + if use_post_sigmul: + ref = torch.nn.functional.sigmoid(ref) * mat_mul + out = torch.ops.sgl_kernel.fused_linear_sigmoid_mul( + mat1, + torch.ops.sgl_kernel.convert_weight_packed(mat2), + bias if has_bias else None, + True, + mat_mul if use_post_sigmul else None, + ) + else: + out = torch.ops.sgl_kernel.weight_packed_linear( + mat1, + torch.ops.sgl_kernel.convert_weight_packed(mat2), + bias if has_bias else None, + True, + ) + atol = rtol = precision[ref.dtype] + torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) + + def test_bf16_gemm_with_small_oc(self): + for params in itertools.product( + [1, 8, 32, 1024], [12, 1], self.K, self.has_bias, [False, True] + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + has_bias=params[3], + use_post_sigmul=params[4], + ): + self._bf16_gemm_with_small_oc(*params) + + def _int8_gemm(self, M, N, K, has_bias): + dtype = torch.bfloat16 + A = torch.randn((M, K), dtype=dtype) / 10 + Aq, As = 
per_token_quant_int8(A) + + factor_for_scale = 1e-2 + int8_max = 127 + int8_min = -128 + + B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2 + Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + Bs = torch.rand(N) * factor_for_scale + + bias = torch.randn(N) if has_bias else None + ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype) + + atol = rtol = precision[ref_out.dtype] + + Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A) + out = torch.ops.sgl_kernel.int8_scaled_mm_cpu( + Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False + ) + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + # test the fused version + fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant( + A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False + ) + torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol) + + def test_int8_gemm(self): + for params in itertools.product( + self.M_int8, + self.N_int8, + self.K_int8, + self.has_bias, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + has_bias=params[3], + ): + self._int8_gemm(*params) + + def _fp8_gemm(self, M, N, K, has_bias): + prepack = True + chunk = False + scale_block_size_N = 64 + scale_block_size_K = 128 + assert scale_block_size_N <= N + assert scale_block_size_K <= K + A_dtype = torch.bfloat16 + + model = Mod(K, N, has_bias).eval() + if chunk: + data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K) + else: + data = torch.randn(M, K, dtype=A_dtype) + + weight = model.linear.weight # (N, K) + + if has_bias: + bias = model.linear.bias + + fp8_weight, scales, dq_weight = convert_weight( + weight, [scale_block_size_N, scale_block_size_K], A_dtype + ) + + if has_bias: + ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype) + else: + ref = torch.matmul(data.to(A_dtype), dq_weight.T) + + if prepack: + fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight) + + opt = 
torch.ops.sgl_kernel.fp8_scaled_mm_cpu( + data, + fp8_weight, + scales, + [scale_block_size_N, scale_block_size_K], + bias if has_bias else None, + data.dtype, + prepack, + ) + atol = rtol = precision[ref.dtype] + torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol) + + def test_fp8_gemm(self): + for params in itertools.product( + self.M_fp8, + self.N_fp8, + self.K_fp8, + self.has_bias, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + has_bias=params[3], + ): + self._fp8_gemm(*params) + + def _int4_awq_gemm(self, M, N, K, group_size, has_bias): + awq_weight = torch.randint(-128, 128, (K, N // 8)).to(torch.int) + awq_zero = torch.randint(0, 10, (K // group_size, N // 8)).to(torch.int) + awq_scales = torch.rand(int(K // group_size), N).to(torch.bfloat16) + bf16_weight, _ = unpack_and_dequant_awq( + awq_weight, awq_zero, awq_scales, 4, 128 + ) + if has_bias: + bias = torch.rand(bf16_weight.shape[0]).to(torch.float) + else: + bias = None + x = torch.rand(M, bf16_weight.size(-1)).to(torch.bfloat16) + ref_res = torch.nn.functional.linear( + x, bf16_weight, bias=bias.to(torch.bfloat16) if has_bias else None + ) + + packed_weight, packed_zero, packed_scales = ( + torch.ops.sgl_kernel.convert_weight_packed_scale_zp( + awq_weight, awq_zero, awq_scales, 0 + ) + ) + target_res = torch.ops.sgl_kernel.int4_scaled_mm_cpu( + x, + packed_weight, + packed_zero, + packed_scales, + bias, + ) + + atol = rtol = precision[ref_res.dtype] + torch.testing.assert_close(ref_res, target_res, atol=atol, rtol=rtol) + + def test_int4_awq_gemm(self): + for params in itertools.product( + self.M_awq, self.N_awq, self.K_awq, [128], self.has_bias + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + group_size=params[3], + has_bias=params[4], + ): + self._int4_awq_gemm(*params) + + def _int4_gptq_gemm(self, M, N, K, group_size, has_bias): + torch.manual_seed(127) + gptq_weight = torch.randint(-128, 128, (K // 8, N)).to(torch.int) + gptq_zero = 
torch.randint(0, 10, (K // group_size, N // 8)).to(torch.int) + gptq_scales = torch.rand(int(K // group_size), N).to(torch.bfloat16) // 10 + + bf16_weight = unpack_and_dequant_gptq(gptq_weight, gptq_zero, gptq_scales) + if has_bias: + bias = torch.rand(bf16_weight.shape[0]).to(torch.float) + else: + bias = None + x = torch.rand(M, bf16_weight.size(-1)).to(torch.bfloat16) + ref_res = torch.nn.functional.linear( + x, bf16_weight, bias=bias.to(torch.bfloat16) if has_bias else None + ) + + packed_weight, packed_zero, packed_scales = ( + torch.ops.sgl_kernel.convert_weight_packed_scale_zp( + gptq_weight, gptq_zero, gptq_scales, 1 + ) + ) + target_res = torch.ops.sgl_kernel.int4_scaled_mm_cpu( + x, + packed_weight, + packed_zero, + packed_scales, + bias, + ) + + atol = rtol = precision[ref_res.dtype] + torch.testing.assert_close(ref_res, target_res, atol=atol, rtol=rtol) + + def test_int4_gptq_gemm(self): + for params in itertools.product( + self.M_gptq, self.N_gptq, self.K_gptq, [128], self.has_bias + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + group_size=params[3], + has_bias=params[4], + ): + self._int4_gptq_gemm(*params) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_intel_amx_attention_backend_a.py b/test/registered/cpu/test_intel_amx_attention_backend_a.py new file mode 100644 index 000000000000..ba60527d2e25 --- /dev/null +++ b/test/registered/cpu/test_intel_amx_attention_backend_a.py @@ -0,0 +1,76 @@ +""" +Usage: +python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_latency_default_model +""" + +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + 
CustomTestCase, + intel_amx_benchmark, + is_in_ci, + popen_launch_server, +) + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + + +class TestIntelAMXAttnBackend(CustomTestCase): + + @intel_amx_benchmark( + extra_args=["--batch-size", "4", "--mem-fraction-static", "0.3"], + min_throughput=10, + ) + def test_latency_mla_model(self): + return DEFAULT_MLA_MODEL_NAME_FOR_TEST + + @intel_amx_benchmark( + extra_args=["--batch-size", "4", "--mem-fraction-static", "0.1"], + min_throughput=40, + ) + def test_latency_default_model(self): + return DEFAULT_MODEL_NAME_FOR_TEST + + def test_mmlu(self): + model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", + "intel_amx", + "--mem-fraction-static", + "0.3", + "--disable-radix", + "--trust-remote-code", + "--disable-overlap-schedule", + ], + ) + + try: + args = SimpleNamespace( + base_url=base_url, + model=model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + metrics = run_eval(args) + if is_in_ci(): + self.assertGreater(metrics["score"], 0.45) + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_intel_amx_attention_backend_b.py b/test/registered/cpu/test_intel_amx_attention_backend_b.py new file mode 100644 index 000000000000..104f417c5bd2 --- /dev/null +++ b/test/registered/cpu/test_intel_amx_attention_backend_b.py @@ -0,0 +1,38 @@ +""" +For intel_amx attention backend FP8 tests +Usage: +python3 -m unittest test_intel_amx_attention_backend_1.TestIntelAMXAttnBackendQuant.test_latency_fp8_qwen +""" + +import unittest + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8, + CustomTestCase, + intel_amx_benchmark, +) + +register_cpu_ci(est_time=10, 
suite="stage-b-test-cpu") + + +class TestIntelAMXAttnBackendQuant(CustomTestCase): + + @intel_amx_benchmark( + extra_args=["--batch-size", "4", "--mem-fraction-static", "0.1"], + min_throughput=150, + ) + def test_latency_fp8_qwen(self): + return DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8 + + @intel_amx_benchmark( + extra_args=["--batch-size", "4", "--mem-fraction-static", "0.1"], + min_throughput=50, + ) + def test_latency_fp8_moe_model(self): + return DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_intel_amx_attention_backend_c.py b/test/registered/cpu/test_intel_amx_attention_backend_c.py new file mode 100644 index 000000000000..391d21a2b095 --- /dev/null +++ b/test/registered/cpu/test_intel_amx_attention_backend_c.py @@ -0,0 +1,56 @@ +""" +For intel_amx attention backend w8a8 tests +Usage: +python3 -m unittest test_intel_amx_attention_backend_2.TestIntelAMXAttnBackendQuant.test_latency_w8a8_default_model +""" + +import unittest + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST_W8A8, + DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, + CustomTestCase, + intel_amx_benchmark, +) + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + + +class TestIntelAMXAttnBackendQuant(CustomTestCase): + + @intel_amx_benchmark( + extra_args=[ + "--batch-size", + "4", + "--quantization", + "w8a8_int8", + "--mem-fraction-static", + "0.1", + ], + min_throughput=100, + ) + def test_latency_w8a8_default_model(self): + return DEFAULT_MODEL_NAME_FOR_TEST_W8A8 + + @intel_amx_benchmark( + extra_args=[ + "--batch-size", + "4", + "--quantization", + "w8a8_int8", + "--mem-fraction-static", + "0.9", + "--max-total-tokens", + "65536", + "--tp", + "6", + ], + min_throughput=100, + ) + def test_latency_w8a8_moe_model(self): + return DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE + + +if __name__ == "__main__": + unittest.main() diff --git 
a/test/registered/cpu/test_mamba.py b/test/registered/cpu/test_mamba.py new file mode 100644 index 000000000000..25275e908e77 --- /dev/null +++ b/test/registered/cpu/test_mamba.py @@ -0,0 +1,397 @@ +import unittest + +import torch +import torch.nn.functional as F +from torch.nn.functional import softplus +from utils import precision + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, sequence_length, num_heads, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - num_heads % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + tot_heads = num_heads + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask 
= torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + # for each chunk + for i in range(0, tot_heads // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( + -1, -2 + ) + @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape( + core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1] + ) + core_attn_out = core_attn_out[:, :, :num_heads] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def chunk_gated_delta_rule_update( + query, # [B, T, HK, K] + key, # [B, T, HK, 
K] + value, # [B, T, HV, V] + g, # [B, T, HV] + beta, # [B, T, HV] + cu_seqlens, # [N+1] + initial_state, # [N, HV, K, V] + use_qk_l2norm_in_kernel, # True +): + num_heads = query.shape[2] + num_value_heads = value.shape[2] + batch_size = initial_state.shape[0] + if num_value_heads // num_heads > 1: + query = query.repeat_interleave(num_value_heads // num_heads, dim=2) + key = key.repeat_interleave(num_value_heads // num_heads, dim=2) + output = torch.empty_like(value) + final_state = torch.empty_like(initial_state) + start_q = 0 + for i in range(batch_size): + end_q = cu_seqlens[i + 1] + core_attn_outi, last_recurrent_state = torch_chunk_gated_delta_rule( + query=query[:, start_q:end_q, :, :], + key=key[:, start_q:end_q, :, :], + value=value[:, start_q:end_q, :, :], + g=g[:, start_q:end_q, :], + beta=beta[:, start_q:end_q, :], + initial_state=initial_state[i], + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + output[:, start_q:end_q, :, :] = core_attn_outi + final_state[i] = last_recurrent_state + start_q = end_q + return output, final_state + + +def torch_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to( + value + ) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] 
+ k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze( + -1 + ) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def sigmoid_gating_delta_rule_update( + query, + key, + value, + A_log, + a, + dt_bias, + b, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel=False, +): + beta = b.sigmoid() + g = -A_log.float().exp() * softplus(a.float() + dt_bias) + return torch_recurrent_gated_delta_rule( + query, + key, + value, + g.unsqueeze(0), + beta.unsqueeze(0), + initial_state, + output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + +def torch_gdn_gating(A_log, a, b, dt_bias): + return -A_log.float().exp() * softplus(a.float() + dt_bias).unsqueeze( + 0 + ), b.sigmoid().unsqueeze(0) + + +class TestMambaAttention(CustomTestCase): + def test_chunk_gated_delta_rule(self): + B, L, HK, HV, EK, EV, N = 1, 100, 3, 6, 64, 64, 4 + seqlens = torch.randint(1, L, (N + 1,)) + seqlens[0] = 0 + cu_seqlens_ = torch.cumsum(seqlens, dim=0).to(torch.int32) + T = cu_seqlens_[-1].item() + query_ = torch.rand((B, T, HK, EK), dtype=torch.bfloat16) * 0.05 + key_ = torch.rand((B, T, HK, EK), dtype=torch.bfloat16) * 0.05 + value_ = torch.rand((B, T, HV, EV), dtype=torch.bfloat16) * 0.05 + g_ = torch.rand((B, T, HV), dtype=torch.float32) * 0.05 + beta_ = torch.rand((B, T, HV), dtype=torch.bfloat16) * 0.05 + initial_state_ = torch.rand((N, HV, EK, EV), dtype=torch.float32) * 0.05 + + for use_qk_l2norm_in_kernel in [True, False]: + 
core_attn_out_ref, last_recurrent_state_ref = chunk_gated_delta_rule_update( + query=query_, + key=key_, + value=value_, + g=g_, + beta=beta_, + cu_seqlens=cu_seqlens_, + initial_state=initial_state_, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + query = query_.clone() + key = key_.clone() + value = value_.clone() + g = g_.clone() + beta = beta_.clone() + cu_seqlens = cu_seqlens_.clone() + initial_state = initial_state_.clone() + + core_attn_out, last_recurrent_state = ( + torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu( + query=query, + key=key, + value=value, + g=g, + beta=beta, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + ) + atol = rtol = precision[core_attn_out.dtype] + torch.testing.assert_close( + core_attn_out, core_attn_out_ref, atol=atol, rtol=rtol + ) + torch.testing.assert_close( + last_recurrent_state, last_recurrent_state_ref, atol=atol, rtol=rtol + ) + + def test_fused_gdn_gating(self): + dims = [6, 32] + for dim in dims: + for A_log_dtype in [torch.float32, torch.bfloat16]: + A_log = torch.rand(dim, dtype=A_log_dtype) + a = torch.rand(1024, dim, dtype=torch.bfloat16) + b = torch.rand(1024, dim, dtype=torch.bfloat16) + dt_bias = torch.rand(dim, dtype=torch.bfloat16) + + g, beta = torch_gdn_gating(A_log, a, b, dt_bias) + g_sgl, beta_sgl = torch.ops.sgl_kernel.fused_gdn_gating_cpu( + A_log, a, b, dt_bias + ) + atol = rtol = precision[g.dtype] + atol2 = rtol2 = precision[beta.dtype] + torch.testing.assert_close(g, g_sgl, atol=atol, rtol=rtol) + torch.testing.assert_close(beta, beta_sgl, atol=atol2, rtol=rtol2) + + def test_fused_sigmoid_gating_delta_rule_update(self): + batch_size = 1 + num_value_heads = 32 + head_k_dim = 128 + head_v_dim = 128 + num_heads = 16 + seq_len = 1 + attn_tp_size = 1 + key_dim = head_k_dim * num_heads + value_dim = head_v_dim * num_value_heads + mixed_qkv_dim = (key_dim * 2 + value_dim) // 
attn_tp_size + mixed_qkv = torch.rand( + seq_len * batch_size, mixed_qkv_dim, dtype=torch.bfloat16 + ) + query, key, value = torch.split( + mixed_qkv, + [ + key_dim // attn_tp_size, + key_dim // attn_tp_size, + value_dim // attn_tp_size, + ], + dim=-1, + ) + query = query.view(1, seq_len, num_heads, head_k_dim) + key = key.view(1, seq_len, num_heads, head_k_dim) + value = value.view(1, seq_len, num_value_heads, head_v_dim) + A_log = torch.rand(num_value_heads, dtype=torch.float32) + a = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16) + b = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16) + dt_bias = torch.rand(num_value_heads, dtype=torch.bfloat16) + ssm_states = torch.rand( + 513, num_value_heads, head_k_dim, head_v_dim, dtype=torch.float32 + ) + cache_indices = torch.randint(0, 513, (batch_size,), dtype=torch.int32) + query_start_loc = torch.tensor([0, 1], dtype=torch.int32) + use_qk_l2norm_in_kernel = True + query_ref = query.clone() + key_ref = key.clone() + if num_value_heads // num_heads > 1: + query_ref = query_ref.repeat_interleave(num_value_heads // num_heads, dim=2) + key_ref = key_ref.repeat_interleave(num_value_heads // num_heads, dim=2) + for A_log_dtype in [torch.float32, torch.bfloat16]: + A_log = A_log.to(A_log_dtype) + core_attn_out_ref, last_recurrent_state_ref = ( + sigmoid_gating_delta_rule_update( + query_ref.transpose(0, 1), + key_ref.transpose(0, 1), + value.transpose(0, 1), + A_log, + a, + dt_bias, + b, + initial_state=ssm_states[cache_indices], + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + ) + core_attn_out = ( + torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu( + A_log=A_log, + dt_bias=dt_bias, + q=query, + k=key, + v=value, + a=a, + b=b, + initial_state_source=ssm_states, + initial_state_indices=cache_indices, + cu_seqlens=query_start_loc, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + softplus_beta=1.0, + softplus_threshold=20.0, + ) + ) + 
last_recurrent_state = ssm_states[cache_indices] + atol = rtol = precision[core_attn_out.dtype] + torch.testing.assert_close( + core_attn_out, core_attn_out_ref, atol=atol, rtol=rtol + ) + torch.testing.assert_close( + last_recurrent_state, last_recurrent_state_ref, atol=atol, rtol=rtol + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_mla.py b/test/registered/cpu/test_mla.py new file mode 100644 index 000000000000..f6714ced5ac5 --- /dev/null +++ b/test/registered/cpu/test_mla.py @@ -0,0 +1,158 @@ +import unittest + +import torch +from torch.nn.functional import scaled_dot_product_attention +from utils import precision + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class TestMLA(CustomTestCase): + def _run_sdpa_forward_decode( + self, + query: torch.Tensor, + output: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + key: torch.Tensor, + loc: torch.Tensor, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + scaling=None, + enable_gqa=False, + causal=False, + ): + # set kv cache + k_cache[loc] = key + + # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size] + query = query.movedim(0, query.dim() - 2) + + start_q, start_kv = 0, 0 + for seq_idx in range(seq_lens.shape[0]): + seq_len_q = 1 + seq_len_kv = seq_lens[seq_idx] + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + + per_req_query = query[:, start_q:end_q, :] + + # get key and value from cache. per_req_tokens contains the kv cache + # index for each token in the sequence. 
+ req_pool_idx = req_pool_indices[seq_idx] + per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] + per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) + per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) + + per_req_out = ( + scaled_dot_product_attention( + per_req_query.unsqueeze(0), + per_req_key.unsqueeze(0), + per_req_value.unsqueeze(0), + enable_gqa=enable_gqa, + scale=scaling, + is_causal=causal, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) + output[start_q:end_q, :, :] = per_req_out + start_q, start_kv = end_q, end_kv + + return output + + def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len): + dtype = torch.bfloat16 + + total_tokens = B * seq_len + sm_scale = 1.0 / (D**0.5) + logit_cap = 0.0 + num_kv_splits = 8 + enable_gqa = H_Q != H_KV + + # q represents the new token being generated, one per batch + q = torch.randn(B, H_Q, D, dtype=dtype) + + # k_buffer and v_buffer represent all previous tokens + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype) + v_buffer = k_buffer.narrow(2, 0, D_V) + + key = torch.randn(B, H_KV, D, dtype=dtype) + value = key.narrow(2, 0, D_V) + # make sure no duplicates in loc + loc = torch.randperm(total_tokens)[:B].to(torch.int64) + + k_buffer2 = k_buffer.clone() + v_buffer2 = k_buffer2.narrow(2, 0, D_V) + + # o will have the same shape as q + o = torch.zeros(B, H_Q, D_V, dtype=dtype) + o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype) + + req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32) + b_req_idx = torch.arange(B).to(torch.int64) + b_seq_len = torch.full((B,), seq_len).to(torch.int64) + + attn_logits = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, + ) + + torch.ops.sgl_kernel.decode_attention_cpu( + q, + k_buffer2, + v_buffer2, + o, + key, + value, + loc, + attn_logits, + req_to_token, + b_req_idx, + b_seq_len, + sm_scale, + logit_cap, + ) + + self._run_sdpa_forward_decode( + q, + o_grouped, + 
k_buffer, + v_buffer, + key, + loc, + req_to_token, + b_req_idx, + b_seq_len, + scaling=sm_scale, + enable_gqa=enable_gqa, + ) + + cos_sim = torch.nn.functional.cosine_similarity( + o.flatten(), o_grouped.flatten(), dim=0 + ) + atol = rtol = precision[q.dtype] + self.assertGreater(cos_sim.item(), 0.99) + torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol) + torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol) + torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol) + + def test_grouped_decode_attention(self): + configs = [ + (1, 22, 1, 576, 512, 8 * 111), + (4, 22, 1, 576, 512, 8 * 128), + (40, 22, 1, 576, 512, 8 * 133), + ] + + for B, H_Q, H_KV, D, D_V, seqlen in configs: + self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_moe.py b/test/registered/cpu/test_moe.py new file mode 100644 index 000000000000..f3c817bec331 --- /dev/null +++ b/test/registered/cpu/test_moe.py @@ -0,0 +1,355 @@ +import itertools +import math +import unittest + +# TODO: use interface in cpu.py +import torch + +from sglang.srt.layers.amx_utils import CPUQuantMethod + +kernel = torch.ops.sgl_kernel + +torch.manual_seed(128) + +from utils import ( + BLOCK_K, + BLOCK_N, + factor_for_scale, + fp8_max, + fp8_min, + native_fp8_fused_moe, + precision, + scaled_weight, + torch_naive_fused_moe, + torch_w8a8_per_column_fused_moe, + unpack_and_dequant_awq, +) + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + + +def fused_moe(a, w1, w2, score, topk, renormalize, prepack): + + G = 1 + topk_group = 1 + + B, D = a.shape + topk_weights = torch.empty(B, topk, dtype=torch.float32) + topk_ids = torch.empty(B, topk, dtype=torch.int32) + topk_weights, topk_ids = kernel.grouped_topk_cpu( + a, score, topk, renormalize, G, topk_group, 0, None, None + 
) + + packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1 + packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2 + + inplace = True + return kernel.fused_experts_cpu( + a, + packed_w1, + packed_w2, + topk_weights, + topk_ids, + inplace, + CPUQuantMethod.UNQUANT, + None, + None, + None, + None, + None, + prepack, + ) + + +class TestFusedExperts(CustomTestCase): + M = [2, 114] + N = [32] + K = [32] + E = [4] + topk = [2] + renormalize = [False, True] + + M_int8 = [1, 39] + N_int8 = [128] + K_int8 = [256] + E_int8 = [8] + topk_int8 = [3] + + M_fp8 = [2, 121] + N_fp8 = [352, 512] + K_fp8 = [256, 320] + E_fp8 = [8] + topk_fp8 = [4] + + M_int4 = [1, 6] + N_int4 = [512] + K_int4 = [256] + E_int4 = [8] + topk_int4 = [4] + + def _bf16_moe(self, m, n, k, e, topk, renormalize): + dtype = torch.bfloat16 + prepack = True + + a = torch.randn((m, k), device="cpu", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10 + score = torch.randn((m, e), device="cpu", dtype=dtype) + + torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize) + fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack) + + atol = rtol = precision[torch_output.dtype] + torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol) + + def test_bf16_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.topk, + self.renormalize, + ): + with self.subTest( + m=params[0], + n=params[1], + k=params[2], + e=params[3], + topk=params[4], + renormalize=params[5], + ): + self._bf16_moe(*params) + + def _int8_moe(self, M, N, K, E, topk): + dtype = torch.bfloat16 + prepack = True + + # Initialize int8 quantization parameters + int8_factor_for_scale = 1e-2 + int8_max = 127 + int8_min = -128 + + # Input tensor + # M * K + a = torch.randn((M, K), dtype=dtype) / math.sqrt(K) + + # Generate int8 weights + w1_fp32 = (torch.rand((E, 
2 * N, K), dtype=torch.float32) - 0.5) * 2 + w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 + w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + + # Generate scale for each column (per-column quantization) + w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale + w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale + + # Calculate routing + score = torch.randn((M, E), dtype=dtype) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + + ref_out = torch_w8a8_per_column_fused_moe( + a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk + ) + + inplace = True + packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1 + packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2 + out = kernel.fused_experts_cpu( + a, + packed_w1, + packed_w2, + topk_weight, + topk_ids.to(torch.int32), + inplace, + CPUQuantMethod.INT8_W8A8, + w1_s, + w2_s, + None, + None, + None, + prepack, + ) + + atol = rtol = precision[ref_out.dtype] + # Increase the tolerance for large input shapes + if M > 35: + atol = rtol = 0.02 + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + def test_int8_moe(self): + for params in itertools.product( + self.M_int8, + self.N_int8, + self.K_int8, + self.E_int8, + self.topk_int8, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + ): + self._int8_moe(*params) + + def _fp8_moe(self, M, N, K, E, topk): + dtype = torch.bfloat16 + + a = torch.randn(M, K, dtype=dtype) / math.sqrt(K) + + w1_fp32 = torch.randn(E, 2 * N, K) + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = torch.randn(E, K, N) + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w1s = ( + torch.randn(E, math.ceil(2 * N / BLOCK_N), 
math.ceil(K / BLOCK_K)) + * factor_for_scale + ) + w2s = ( + torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K)) + * factor_for_scale + ) + + w1_scaled = scaled_weight(w1, w1s) + w2_scaled = scaled_weight(w2, w2s) + + score = torch.randn((M, E), dtype=dtype) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + + w1 = kernel.convert_weight_packed(w1) + w2 = kernel.convert_weight_packed(w2) + + ref_out = native_fp8_fused_moe( + a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk + ) + out = kernel.fused_experts_cpu( + a, + w1, + w2, + topk_weight, + topk_ids.to(torch.int32), + False, + CPUQuantMethod.FP8_W8A16, + w1s, + w2s, + None, + None, + [BLOCK_N, BLOCK_K], + True, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol) + + def test_fp8_moe(self): + for params in itertools.product( + self.M_fp8, + self.N_fp8, + self.K_fp8, + self.E_fp8, + self.topk_fp8, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + ): + self._fp8_moe(*params) + + def _int4_moe(self, M, N, K, E, topk, group_size=128): + dtype = torch.bfloat16 + + a = torch.rand(M, K, dtype=dtype) / math.sqrt(K) + + awq_w13_weight = torch.randint(-127, 128, (E, K, 2 * N // 8)).to(torch.int) + awq_w13_zero = torch.randint(0, 10, (E, K // group_size, 2 * N // 8)).to( + torch.int + ) + awq_w13_scales = torch.rand(E, int(K // group_size), 2 * N).to(torch.bfloat16) + + awq_w2_weight = torch.randint(-127, 128, (E, N, K // 8)).to(torch.int) + awq_w2_zero = torch.randint(0, 10, (E, N // group_size, K // 8)).to(torch.int) + awq_w2_scales = torch.rand(E, int(N // group_size), K).to(torch.bfloat16) + bf16_w13_weight = [] + bf16_w2_weight = [] + for i in range(E): + bf16_w13_weight_i, _ = unpack_and_dequant_awq( + awq_w13_weight[i], awq_w13_zero[i], awq_w13_scales[i], 4, 128 + ) + bf16_w2_weight_i, _ = unpack_and_dequant_awq( + 
awq_w2_weight[i], awq_w2_zero[i], awq_w2_scales[i], 4, 128 + ) + bf16_w13_weight.append(bf16_w13_weight_i) + bf16_w2_weight.append(bf16_w2_weight_i) + bf16_w13_weight = torch.stack(bf16_w13_weight).detach() + bf16_w2_weight = torch.stack(bf16_w2_weight).detach() + + score = torch.rand((M, E), dtype=dtype) + + ref_out = torch_naive_fused_moe( + a, bf16_w13_weight, bf16_w2_weight, score, topk, False + ) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + awq_w13_weight_pack, awq_w13_zero_pack, awq_w13_scales_pack = ( + torch.ops.sgl_kernel.convert_weight_packed_scale_zp( + awq_w13_weight, awq_w13_zero, awq_w13_scales, 0 + ) + ) + awq_w2_weight_pack, awq_w2_zero_pack, awq_w2_scales_pack = ( + torch.ops.sgl_kernel.convert_weight_packed_scale_zp( + awq_w2_weight, awq_w2_zero, awq_w2_scales, 0 + ) + ) + + out = kernel.fused_experts_cpu( + a, + awq_w13_weight_pack, + awq_w2_weight_pack, + topk_weight, + topk_ids.to(torch.int32), + False, + CPUQuantMethod.INT4_W4A8, + awq_w13_scales_pack, + awq_w2_scales_pack, + awq_w13_zero_pack, + awq_w2_zero_pack, + None, + True, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol) + + def test_int4_moe(self): + for params in itertools.product( + self.M_int4, + self.N_int4, + self.K_int4, + self.E_int4, + self.topk_int4, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + ): + self._int4_moe(*params) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_norm.py b/test/registered/cpu/test_norm.py new file mode 100644 index 000000000000..2da4d5117a99 --- /dev/null +++ b/test/registered/cpu/test_norm.py @@ -0,0 +1,435 @@ +import itertools +import unittest +from typing import Optional, Tuple, Union + +import torch +from utils import make_non_contiguous, parametrize, precision + +from sglang.test.ci.ci_register import register_cpu_ci 
+from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class TestNorm(CustomTestCase): + + def _forward_native( + self, + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float = 1e-6, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + variance_epsilon) + x = x.to(orig_dtype) * weight + if residual is None: + return x + else: + return x, residual + + def _norm(self, x, eps): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + def _gemma3_rmsnorm_native( + self, x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6 + ): + output = self._norm(x.float(), variance_epsilon) + output = output * (1.0 + weight.float()) + return output.type_as(x) + + def _gemma_rmsnorm_native( + self, + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float = 1e-6, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + variance_epsilon) + x = x * (1.0 + weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + @parametrize( + m=[4096, 1024], + n=[4096, 4109], + dtype=[torch.float16, torch.bfloat16], + ) + def test_norm(self, m, n, dtype): + + x = torch.randn([m, n], dtype=dtype) + x = make_non_contiguous(x) + hidden_size = x.size(-1) + weight = torch.randn(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + + out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon) + ref_out = 
self._forward_native(x, weight, variance_epsilon) + + atol = rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + ref_x = x.clone() + residual = torch.randn([m, hidden_size], dtype=dtype) + ref_residual = residual.clone() + + torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( + x, residual, weight, variance_epsilon + ) + + ref_x, ref_residual = self._forward_native( + ref_x, weight, variance_epsilon, ref_residual + ) + + torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) + torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) + + @parametrize( + l=[1, 2], + m=[4096, 1024], + n=[4096, 4109], + dtype=[torch.float16, torch.bfloat16], + ) + def test_norm_3d(self, l, m, n, dtype): + + x = torch.randn([l, m, n], dtype=dtype) + x = make_non_contiguous(x) + hidden_size = x.size(-1) + weight = torch.randn(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + + out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon) + ref_out = self._forward_native(x, weight, variance_epsilon) + + atol = rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + ref_x = x.clone() + residual = torch.randn([l, m, hidden_size], dtype=dtype) + ref_residual = residual.clone() + + torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( + x, residual, weight, variance_epsilon + ) + + ref_x, ref_residual = self._forward_native( + ref_x, weight, variance_epsilon, ref_residual + ) + + torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) + torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) + + @parametrize( + m=[4096, 1024], + n=[4096, 4109], + dtype=[torch.float16, torch.bfloat16], + ) + def test_l2norm(self, m, n, dtype): + + x = torch.randn([m, n], dtype=dtype) + hidden_size = x.size(-1) + fake_ones_weight = torch.ones(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + + out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon) + ref_out = 
self._forward_native(x, fake_ones_weight, variance_epsilon) + + atol = rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + @parametrize( + m=[4096, 1024], + n=[4096, 4109], + dtype=[torch.float16, torch.bfloat16], + ) + def test_gemma_rmsnorm(self, m, n, dtype): + + x = torch.randn([m, n], dtype=dtype) + x = make_non_contiguous(x) + hidden_size = x.size(-1) + weight = torch.randn(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + + out = torch.ops.sgl_kernel.gemma_rmsnorm_cpu(x, weight, variance_epsilon) + ref_out = self._gemma_rmsnorm_native(x, weight, variance_epsilon) + + atol = rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + ref_x = x.clone() + residual = torch.randn([m, hidden_size], dtype=dtype) + ref_residual = residual.clone() + + torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu( + x, residual, weight, variance_epsilon + ) + + ref_x, ref_residual = self._gemma_rmsnorm_native( + ref_x, weight, variance_epsilon, ref_residual + ) + + torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) + torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) + + @parametrize( + m=[4096, 1024], + n=[4096, 4109], + dtype=[torch.float16, torch.bfloat16], + ) + def test_gemma3_rmsnorm(self, m, n, dtype): + x_list = [ + torch.randn([m, n], dtype=dtype), + torch.randn([1, m, 2, n], dtype=dtype), + ] + for x in x_list: + x = make_non_contiguous(x) + hidden_size = x.size(-1) + weight = torch.randn(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + out = torch.ops.sgl_kernel.gemma3_rmsnorm_cpu(x, weight, variance_epsilon) + ref_out = self._gemma3_rmsnorm_native(x, weight, variance_epsilon) + + atol = rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + def _gemma4_rmsnorm_native( + self, + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float = 1e-6, + scale_shift: float = 0.0, + with_scale: bool 
= True, + ): + output = self._norm(x.float(), variance_epsilon) + if with_scale: + output = output * (weight.float() + scale_shift) + return output.type_as(x) + + @parametrize( + m=[4096, 1024], + n=[4096, 4109], + dtype=[torch.float16, torch.bfloat16], + ) + def test_gemma4_rmsnorm(self, m, n, dtype): + for scale_shift, with_scale in [ + (0.0, True), + (1.0, True), + (0.0, False), + (1.0, False), + ]: + x_list = [ + torch.randn([m, n], dtype=dtype), + torch.randn([4, m, n], dtype=dtype), + ] + # Add non-block-contiguous 3D input + base = torch.randn([4, 2 * m, n], dtype=dtype) + x_list.append(base[:, :m, :]) + + for x in x_list: + x = make_non_contiguous(x) + hidden_size = x.size(-1) + weight = torch.randn(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + + out = torch.ops.sgl_kernel.gemma4_rmsnorm_cpu( + x, weight, variance_epsilon, scale_shift, with_scale + ) + ref_out = self._gemma4_rmsnorm_native( + x, weight, variance_epsilon, scale_shift, with_scale + ) + + atol = rtol = precision[ref_out.dtype] + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + +class TestFusedRMSNormGated(CustomTestCase): + M = [4096, 1024] + N = [4096, 4096 + 13] + dtype = [torch.float16, torch.bfloat16] + + def _forward_native( + self, + hidden_states: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float = 1e-6, + gate: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + # Norm before gate + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + hidden_states = weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * torch.nn.functional.silu(gate.to(torch.float32)) + + return hidden_states.to(input_dtype) + + def _norm_test(self, m, n, dtype): + + x = torch.randn([m, n], dtype=dtype) + x = make_non_contiguous(x) + batch_size = x.size(0) + hidden_size = x.size(-1) + weight = 
torch.randn(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + gate = torch.randn([batch_size, hidden_size], dtype=dtype) + + out = torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu( + x, weight, gate, variance_epsilon + ) + ref_out = self._forward_native(x, weight, variance_epsilon, gate) + + atol = rtol = precision[ref_out.dtype] * 2 + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) + + def test_norm(self): + for params in itertools.product(self.M, self.N, self.dtype): + with self.subTest(m=params[0], n=params[1], dtype=params[2]): + self._norm_test(*params) + + +class TestLayerNorm(CustomTestCase): + + def _forward_native( + self, + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + residual: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0) + x = (x - mean) * torch.rsqrt(variance + variance_epsilon) + x = x * weight.to(torch.float32) + if bias is not None: + x = x + bias.to(torch.float32) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + @parametrize( + m=[4096, 1024], + n=[4096, 4109], + dtype=[torch.float16, torch.bfloat16], + ) + def test_norm_input_2d(self, m: int, n: int, dtype: torch.dtype) -> None: + x = torch.randn([m, n], dtype=dtype) + x = make_non_contiguous(x) + hidden_size = x.size(-1) + weight = torch.randn(hidden_size, dtype=dtype) + bias = torch.randn(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + + ln_out = torch.ops.sgl_kernel.layernorm_cpu(x, weight, None, variance_epsilon) + ref_ln_out = self._forward_native(x, weight, variance_epsilon) + + atol = rtol = precision[ref_ln_out.dtype] + torch.testing.assert_close(ln_out, ref_ln_out, atol=atol, rtol=rtol) + + ln_out = 
torch.ops.sgl_kernel.layernorm_cpu(x, weight, bias, variance_epsilon) + ref_ln_out = self._forward_native( + x, weight, variance_epsilon, residual=None, bias=bias + ) + torch.testing.assert_close(ln_out, ref_ln_out, atol=atol, rtol=rtol) + + residual = torch.randn([m, hidden_size], dtype=dtype) + ref_residual = residual.clone() + + add_ln_out = torch.ops.sgl_kernel.fused_add_layernorm_cpu( + x, residual, weight, None, variance_epsilon + ) + ref_add_ln_out, ref_residual = self._forward_native( + x, weight, variance_epsilon, residual=ref_residual + ) + + torch.testing.assert_close(add_ln_out, ref_add_ln_out, atol=atol, rtol=rtol) + torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) + + residual = torch.randn([m, hidden_size], dtype=dtype) + ref_residual = residual.clone() + + add_ln_out = torch.ops.sgl_kernel.fused_add_layernorm_cpu( + x, residual, weight, bias, variance_epsilon + ) + ref_add_ln_out, ref_residual = self._forward_native( + x, weight, variance_epsilon, residual=ref_residual, bias=bias + ) + + torch.testing.assert_close(add_ln_out, ref_add_ln_out, atol=atol, rtol=rtol) + torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) + + @parametrize( + l=[4096, 1024], + m=[1, 4], + n=[4096, 4109, 2304], + dtype=[torch.float16, torch.bfloat16], + ) + def test_norm_input_3d(self, l: int, m: int, n: int, dtype: torch.dtype) -> None: + x = torch.randn([l, m, n], dtype=dtype) + x = make_non_contiguous(x) + hidden_size = x.size(-1) + weight = torch.randn(hidden_size, dtype=dtype) + bias = torch.randn(hidden_size, dtype=dtype) + variance_epsilon = 1e-6 + + ln_out = torch.ops.sgl_kernel.layernorm_cpu(x, weight, None, variance_epsilon) + ref_ln_out = self._forward_native(x, weight, variance_epsilon) + + atol = rtol = precision[ref_ln_out.dtype] + torch.testing.assert_close(ln_out, ref_ln_out, atol=atol, rtol=rtol) + + ln_out = torch.ops.sgl_kernel.layernorm_cpu(x, weight, bias, variance_epsilon) + ref_ln_out = 
self._forward_native( + x, weight, variance_epsilon, residual=None, bias=bias + ) + torch.testing.assert_close(ln_out, ref_ln_out, atol=atol, rtol=rtol) + + residual = torch.randn([l, m, hidden_size], dtype=dtype) + ref_residual = residual.clone() + + add_ln_out = torch.ops.sgl_kernel.fused_add_layernorm_cpu( + x, residual, weight, None, variance_epsilon + ) + ref_add_ln_out, ref_residual = self._forward_native( + x, weight, variance_epsilon, ref_residual + ) + + torch.testing.assert_close(add_ln_out, ref_add_ln_out, atol=atol, rtol=rtol) + torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) + + residual = torch.randn([l, m, hidden_size], dtype=dtype) + ref_residual = residual.clone() + + add_ln_out = torch.ops.sgl_kernel.fused_add_layernorm_cpu( + x, residual, weight, bias, variance_epsilon + ) + ref_add_ln_out, ref_residual = self._forward_native( + x, weight, variance_epsilon, residual=ref_residual, bias=bias + ) + + torch.testing.assert_close(add_ln_out, ref_add_ln_out, atol=atol, rtol=rtol) + torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_qkv_proj_with_rope.py b/test/registered/cpu/test_qkv_proj_with_rope.py new file mode 100644 index 000000000000..2fa291860f5d --- /dev/null +++ b/test/registered/cpu/test_qkv_proj_with_rope.py @@ -0,0 +1,443 @@ +import unittest + +import torch +from utils import ( + convert_weight, + native_w8a8_per_token_matmul, + per_token_quant_int8, + precision, +) + +from sglang.srt.layers.quantization.fp8_utils import input_to_float8 +from sglang.srt.layers.rotary_embedding.utils import apply_rotary_emb +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed +qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope 
+qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight +torch.manual_seed(1234) +# constants +kv_lora_rank = 512 +qk_head_dim = 192 +qk_nope_head_dim = 128 +qk_rope_head_dim = 64 +rotary_dim = qk_rope_head_dim +num_heads = 22 +q_lora_rank = 1536 +hidden_size = 7168 +B = 1 +eps = 1e-6 + + +def layernorm(x, weight, variance_epsilon=1e-6, residual=None): + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + variance_epsilon) + return (x * weight).to(orig_dtype) + + +def rotary_emb(q_pe, k_pe, pos, cos_sin_cache): + orig_dtype = q_pe.dtype + q_pe = q_pe.float() + k_pe = k_pe.float() + cos_sin_cache = cos_sin_cache.float() + + query_rot = q_pe[..., :rotary_dim] + key_rot = k_pe[..., :rotary_dim] + cos_sin = cos_sin_cache[pos] + cos, sin = cos_sin.chunk(2, dim=-1) + query_rot = apply_rotary_emb(query_rot, cos, sin, False) + key_rot = apply_rotary_emb(key_rot, cos, sin, False) + return query_rot.to(orig_dtype), key_rot.to(orig_dtype) + + +def native_torch( + q_input, + hidden_states, + q_a_proj_weight, + norm_weight1, + q_b_proj_weight, + w_kc, + kv_a_proj_weight, + norm_weight2, + pos, + cos_sin_cache, +): + + q = torch.matmul(hidden_states, q_a_proj_weight.t()) + q = layernorm(q, norm_weight1) + q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim) + + q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc) + + q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1) + latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t()) + v_input = latent_cache[..., :kv_lora_rank] + v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1) + k_input = latent_cache.unsqueeze(1) + k_input[..., :kv_lora_rank] = v_input + k_pe = k_input[..., kv_lora_rank:] + + q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache) + q_input[..., kv_lora_rank:] = q_pe + k_input[..., 
kv_lora_rank:] = k_pe + + return q_input, k_input, v_input + + +def native_torch_int8( + q_input, + hidden_states, + w1_q, + w1_s, + norm_weight1, + w2_q, + w2_s, + w_kc, + w3_q, + w3_s, + norm_weight2, + pos, + cos_sin_cache, +): + + a_q, a_s = per_token_quant_int8(hidden_states) + q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16) + q = layernorm(q, norm_weight1) + + a_q, a_s = per_token_quant_int8(q) + q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view( + -1, num_heads, qk_head_dim + ) + + q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc) + + q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1) + a_q, a_s = per_token_quant_int8(hidden_states) + latent_cache = native_w8a8_per_token_matmul( + a_q, w3_q, a_s, w3_s, None, torch.bfloat16 + ) + v_input = latent_cache[..., :kv_lora_rank] + v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1) + k_input = latent_cache.unsqueeze(1) + k_input[..., :kv_lora_rank] = v_input + k_pe = k_input[..., kv_lora_rank:] + + q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache) + q_input[..., kv_lora_rank:] = q_pe + k_input[..., kv_lora_rank:] = k_pe + + return q_input, k_input, v_input + + +class TestQKVProjWithROPE(CustomTestCase): + def test_bf16_qkv_proj_with_rope(self): + dtype = torch.bfloat16 + hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size + q_input = torch.empty( + B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype + ) + q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 + norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) + q_b_proj_weight = ( + torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 + ) + w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 + kv_a_proj_weight = ( + torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 + ) + fused_weight = 
torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0) + norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) + pos = torch.randint(10, 100, (B,)) + cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) + q_ref, k_ref, v_ref = native_torch( + q_input, + hidden_states, + q_a_proj_weight, + norm_weight1, + q_b_proj_weight, + w_kc.transpose(1, 2), + kv_a_proj_weight, + norm_weight2, + pos, + cos_sin_cache, + ) + qa_packed = convert_weight_packed(q_a_proj_weight) + qb_packed = convert_weight_packed(q_b_proj_weight) + kva_packed = convert_weight_packed(kv_a_proj_weight) + wkc_packed = convert_weight_packed(w_kc) + fused_weight_packed = convert_weight_packed(fused_weight) + + q_out, k_out, v_out = qkv_proj_with_rope( + hidden_states, + qa_packed, + qb_packed, + kva_packed, + wkc_packed, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + False, + False, + None, + None, + None, + None, + True, + None, + ) + fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight( + hidden_states, + fused_weight_packed, + qb_packed, + wkc_packed, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + False, + False, + None, + None, + None, + True, + None, + q_lora_rank, + kv_lora_rank, + qk_rope_head_dim, + ) + atol = rtol = precision[q_ref.dtype] + torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol) + torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol) + torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_q_out, q_out) + torch.testing.assert_close(fused_k_out, k_out) + torch.testing.assert_close(fused_v_out, v_out) + + def test_int8_qkv_proj_with_rope(self): + dtype = torch.bfloat16 + hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size + q_input = torch.empty( + B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype + ) + q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 + norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) + 
q_b_proj_weight = ( + torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 + ) + w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 + kv_a_proj_weight = ( + torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 + ) + norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) + pos = torch.randint(10, 100, (B,)) + cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) + + w1_q, w1_s = per_token_quant_int8(q_a_proj_weight) + w2_q, w2_s = per_token_quant_int8(q_b_proj_weight) + w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight) + q_ref, k_ref, v_ref = native_torch_int8( + q_input, + hidden_states, + w1_q, + w1_s, + norm_weight1, + w2_q, + w2_s, + w_kc.transpose(1, 2), + w3_q, + w3_s, + norm_weight2, + pos, + cos_sin_cache, + ) + w1_q_packed = convert_weight_packed(w1_q) + w2_q_packed = convert_weight_packed(w2_q) + w3_q_packed = convert_weight_packed(w3_q) + wkc_packed = convert_weight_packed(w_kc) + q_out, k_out, v_out = qkv_proj_with_rope( + hidden_states, + w1_q_packed, + w2_q_packed, + w3_q_packed, + wkc_packed, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + True, + False, + w1_s, + w2_s, + w3_s, + None, + True, + None, + ) + fused_weight = torch.cat([w1_q, w3_q], dim=0) + fused_weight_s = torch.cat([w1_s, w3_s], dim=0) + w_fused_q_packed = convert_weight_packed(fused_weight) + fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight( + hidden_states, + w_fused_q_packed, + w2_q_packed, + wkc_packed, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + True, + False, + fused_weight_s, + w2_s, + None, + True, + None, + q_lora_rank, + kv_lora_rank, + qk_rope_head_dim, + ) + atol = rtol = precision[q_ref.dtype] + torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol) + torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol) + torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_q_out, q_out) + 
torch.testing.assert_close(fused_k_out, k_out) + torch.testing.assert_close(fused_v_out, v_out) + + def test_fp8_qkv_proj_with_rope(self): + dtype = torch.bfloat16 + hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size + q_input = torch.empty( + B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype + ) + q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 + norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) + q_b_proj_weight = ( + torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 + ) + w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 + w_kc_q, w_kc_s = input_to_float8(w_kc) + kv_a_proj_weight = ( + torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 + ) + norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) + pos = torch.randint(10, 100, (B,)) + cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) + + scale_block_size_N = 128 + scale_block_size_K = 128 + fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = ( + convert_weight( + q_a_proj_weight, + [scale_block_size_N, scale_block_size_K], + torch.bfloat16, + ) + ) + fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = ( + convert_weight( + q_b_proj_weight, + [scale_block_size_N, scale_block_size_K], + torch.bfloat16, + ) + ) + ( + fp8_kv_a_proj_with_mqa_weight, + kv_a_proj_with_mqa_weight_scale_inv, + kv_a_proj_with_mqa_weight_dq, + ) = convert_weight( + kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16 + ) + w_kc_dq = w_kc_q.to(torch.bfloat16) * w_kc_s + q_ref, k_ref, v_ref = native_torch( + q_input, + hidden_states, + q_a_proj_weight_dq, + norm_weight1, + q_b_proj_weight_dq, + w_kc_dq.transpose(1, 2), + kv_a_proj_with_mqa_weight_dq, + norm_weight2, + pos, + cos_sin_cache, + ) + fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight) + fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight) + 
fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed( + fp8_kv_a_proj_with_mqa_weight + ) + w_kc_q = convert_weight_packed(w_kc_q) + q_out, k_out, v_out = qkv_proj_with_rope( + hidden_states, + fp8_q_a_proj_weight_packed, + fp8_q_b_proj_weight_packed, + fp8_kv_a_proj_with_mqa_weight_packed, + w_kc_q, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + False, + True, + q_a_proj_weight_scale_inv.float(), + q_b_proj_weight_scale_inv.float(), + kv_a_proj_with_mqa_weight_scale_inv.float(), + w_kc_s, + True, + [scale_block_size_N, scale_block_size_K], + ) + + fused_weight = torch.cat( + [fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0 + ) + fused_weight_s = torch.cat( + [q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0 + ) + fused_weight_packed = convert_weight_packed(fused_weight) + fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight( + hidden_states, + fused_weight_packed, + fp8_q_b_proj_weight_packed, + w_kc_q, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + False, + True, + fused_weight_s.float(), + q_b_proj_weight_scale_inv.float(), + w_kc_s, + True, + [scale_block_size_N, scale_block_size_K], + q_lora_rank, + kv_lora_rank, + qk_rope_head_dim, + ) + atol = rtol = precision[q_ref.dtype] + # Due to the change in multiplication order, the error is amplified. + # In the model, with fewer layers, this doesn't cause issues, but in + # tests with more layers, we need to enlarge the tolerance to pass the tests. 
+ torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol) + torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_q_out, q_out) + torch.testing.assert_close(fused_k_out, k_out) + torch.testing.assert_close(fused_v_out, v_out) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_qwen3.py b/test/registered/cpu/test_qwen3.py new file mode 100644 index 000000000000..e8edfeb649ae --- /dev/null +++ b/test/registered/cpu/test_qwen3.py @@ -0,0 +1,152 @@ +import unittest + +import torch +from utils import precision + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +def fix_query_key_value_ordering_reshape_cat( + mixed_qkvz, mixed_ba, num_k_heads, num_v_heads, attn_tp_size, head_k_dim, head_v_dim +): + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + num_k_heads // attn_tp_size, + ( + head_k_dim + + head_k_dim + + (head_v_dim + head_v_dim) * num_v_heads // num_k_heads + ), + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + num_k_heads // attn_tp_size, + 2 * num_v_heads // num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + head_k_dim, + head_k_dim, + (num_v_heads // num_k_heads * head_v_dim), + (num_v_heads // num_k_heads * head_v_dim), + ] + split_arg_list_ba = [ + num_v_heads // num_k_heads, + num_v_heads // num_k_heads, + ] + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> 
[b, sq, np, hn] + value = value.reshape(value.size(0), -1, head_v_dim) + z = z.reshape(z.size(0), -1, head_v_dim) + b = b.reshape(b.size(0), num_v_heads // attn_tp_size) + a = a.reshape(a.size(0), num_v_heads // attn_tp_size) + query, key, value = map(lambda x: x.reshape(x.shape[0], -1), (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + return mixed_qkv, z, b, a + + +def fix_query_key_value_ordering_reshape_cat_contiguous( + mixed_qkvz: torch.Tensor, + mixed_ba: torch.Tensor, + key_dim: int, + value_dim: int, + num_v_heads: int, + head_v_dim: int, + attn_tp_size: int, +): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + k_tp = key_dim // attn_tp_size + v_tp = value_dim // attn_tp_size + nv_tp = num_v_heads // attn_tp_size + + # Directly split, no head group reshape + query, key, value, z = mixed_qkvz.split([k_tp, k_tp, v_tp, v_tp], dim=-1) + b, a = mixed_ba.split([nv_tp, nv_tp], dim=-1) + + # value / z reshape to (seq, num_v_heads/tp, head_v_dim) + value = value.reshape(value.size(0), -1, head_v_dim) + z = z.reshape(z.size(0), -1, head_v_dim) + query, key, value = map(lambda x: x.reshape(x.shape[0], -1), (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + return mixed_qkv, z, b, a + + +class TestQwen3(CustomTestCase): + def test_fused_qkvzba_split_reshape_cat(self): + mixed_qkvz = torch.rand(1024, 12288, dtype=torch.bfloat16) + mixed_ba = torch.rand(1024, 64, dtype=torch.bfloat16) + head_k_dim = 128 + head_v_dim = 128 + num_v_heads = 32 + num_k_heads = 16 + attn_tp_size = 1 + mixed_qkv_ref, z_ref, b_ref, a_ref = fix_query_key_value_ordering_reshape_cat( + mixed_qkvz, + mixed_ba, + num_k_heads, + num_v_heads, + attn_tp_size, + head_k_dim, + head_v_dim, + ) + num_heads_qk = num_k_heads // attn_tp_size + num_heads_v = num_v_heads // attn_tp_size + mixed_qkv, z, b, a = torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu( + mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_k_dim, 
head_v_dim + ) + atol = rtol = precision[mixed_qkv.dtype] + torch.testing.assert_close(mixed_qkv, mixed_qkv_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(b, b_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(a, a_ref, atol=atol, rtol=rtol) + + def test_fused_qkvzba_split_reshape_cat_contiguous(self): + mixed_qkvz = torch.rand(1, 12288, dtype=torch.bfloat16) + mixed_ba = torch.rand(1, 64, dtype=torch.bfloat16) + head_k_dim = 128 + head_v_dim = 128 + num_v_heads = 32 + num_k_heads = 16 + attn_tp_size = 1 + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + mixed_qkv_ref, z_ref, b_ref, a_ref = ( + fix_query_key_value_ordering_reshape_cat_contiguous( + mixed_qkvz, + mixed_ba, + key_dim, + value_dim, + num_v_heads, + head_v_dim, + attn_tp_size, + ) + ) + num_heads_qk = num_k_heads // attn_tp_size + num_heads_v = num_v_heads // attn_tp_size + mixed_qkv, z, b, a = ( + torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_contiguous_cpu( + mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_k_dim, head_v_dim + ) + ) + atol = rtol = precision[mixed_qkv.dtype] + torch.testing.assert_close(mixed_qkv, mixed_qkv_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(b, b_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(a, a_ref, atol=atol, rtol=rtol) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_rope.py b/test/registered/cpu/test_rope.py new file mode 100644 index 000000000000..7246214d2dac --- /dev/null +++ b/test/registered/cpu/test_rope.py @@ -0,0 +1,285 @@ +import unittest + +import torch +from utils import precision + +from sglang.srt.layers.rotary_embedding import ( + MRotaryEmbedding, + RotaryEmbedding, +) +from sglang.srt.layers.rotary_embedding.rope_variant import ( + DeepseekScalingRotaryEmbedding, + apply_rotary_pos_emb_native, +) +from 
sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class TestROPE(CustomTestCase): + def test_mrope(self): + torch.manual_seed(100) + head_size = 128 + seq_len = 512 + num_heads = 16 + num_kv_heads = 1 + rotary_dim = 128 + max_pos = 262144 + base = 5000000 + is_neox_style = True + dtype = torch.bfloat16 + mrope_section = [24, 20, 20] + mrope_interleaved = True + positions_mrope = torch.randint(0, max_pos, (3, seq_len)) + positions_text = torch.randint(0, max_pos, (seq_len,)) + set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) + + test_config = [ + # (dtype, is_neox_style, mrope_interleaved, positions, mrope_section) + (torch.bfloat16, False, True, positions_mrope, mrope_section), + (torch.bfloat16, False, False, positions_mrope, mrope_section), + (torch.bfloat16, False, False, positions_text, None), + (torch.bfloat16, True, True, positions_mrope, mrope_section), + (torch.bfloat16, True, False, positions_mrope, mrope_section), + (torch.bfloat16, True, False, positions_text, None), + ] + for ( + dtype, + is_neox_style, + mrope_interleaved, + positions, + mrope_section, + ) in test_config: + rope = MRotaryEmbedding( + head_size, + rotary_dim, + max_pos, + base, + is_neox_style, + dtype, + mrope_section, + mrope_interleaved, + ) + enable_autocast = True + + with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast): + q = torch.randn(seq_len, num_heads * head_size, dtype=dtype) + q_clone = q.clone() + k = torch.randn(seq_len, num_kv_heads * head_size, dtype=dtype) + k_clone = k.clone() + + # ref kernel + q_ref, k_ref = rope.forward_native( + query=q, + key=k, + positions=positions, + ) + # fused rope kernel + q_sgl, k_sgl = torch.ops.sgl_kernel.multimodal_rotary_embedding_cpu( + positions, + q_clone, + k_clone, +
rope.head_size, + rope.cos_sin_cache, + rope.mrope_section, + rope.mrope_interleaved, + is_neox_style, + ) + atol = rtol = precision[q_ref.dtype] + torch.testing.assert_close(q_ref, q_sgl, atol=atol, rtol=rtol) + torch.testing.assert_close(k_ref, k_sgl, atol=atol, rtol=rtol) + + def test_deepseek_v2_rope(self): + num_head = 16 + seq_len = 1024 + q_head_dim = 192 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + max_pos = 256 + k_dim = 576 + rotary_dim = 64 + is_neox_style = False + set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) + + # Create cos_sin_cache + freqs = torch.rand(max_pos, qk_rope_head_dim // 2) + cos = freqs.cos() * 0.7 + sin = freqs.sin() * 0.7 + cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16) + positions = torch.randint(0, max_pos, (seq_len,)) + + rope = DeepseekScalingRotaryEmbedding( + qk_rope_head_dim, + rotary_dim, + max_pos, + 16, # not used since cos_sin_cache is provided + is_neox_style, + 1.0, + torch.bfloat16, + device="cpu", + ) + rope.register_buffer("cos_sin_cache", cos_sin_cache) + + for dtype in [torch.bfloat16]: + enable_autocast = True + + with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast): + q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype) + q_clone = q.clone() + k = torch.randn(seq_len, 1, k_dim, dtype=dtype) + k_clone = k.clone() + _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + _, q_pe_clone = q_clone.split( + [qk_nope_head_dim, qk_rope_head_dim], dim=-1 + ) + k_pe = k[:, :, k_dim - qk_rope_head_dim :] + k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :] + + # ref kernel + q_pe, k_pe = rope.forward_native( + query=q_pe, + key=k_pe, + positions=positions, + ) + + # fused rope kernel + q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu( + positions, + q_pe_clone, + k_pe_clone, + rope.head_size, + cos_sin_cache, + False, + ) + + atol = rtol = precision[q_pe.dtype] + torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, 
rtol=rtol) + torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol) + torch.testing.assert_close(k_pe, k_pe_clone) + + def test_origin_rope(self): + def single_test( + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + dims: int, + is_neox_style: bool, + dtype: torch.dtype, + device: str, + batch_size: int, + seq_len: int, + num_q_heads: int, + num_kv_heads: int, + ): + set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) + torch.manual_seed(100) + rope_ref = RotaryEmbedding( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + ).to(device) + pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) + query = torch.randn( + batch_size * seq_len, + num_q_heads * head_size, + dtype=dtype, + device=device, + ) + key = torch.randn( + batch_size * seq_len, + num_kv_heads * head_size, + dtype=dtype, + device=device, + ) + if dims == 4: + query = query.view(batch_size, seq_len, num_q_heads, head_size) + key = key.view(batch_size, seq_len, num_kv_heads, head_size) + query_ref, key_ref = query.clone(), key.clone() + query_cpu, key_cpu = query.clone(), key.clone() + + query_ref_out, key_ref_out = rope_ref.forward_native( + pos_ids, query_ref, key_ref + ) + query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu( + pos_ids, + query_cpu, + key_cpu, + rope_ref.head_size, + rope_ref.cos_sin_cache.to(query.dtype), + rope_ref.is_neox_style, + ) + torch.testing.assert_close( + query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2) + + test_config = [ + (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8), + (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 
4), + (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2), + ] + + for ( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + batch_size, + seq_len, + num_q_heads, + num_kv_heads, + ) in test_config: + for dim in [2, 4]: + single_test( + head_size, + rotary_dim, + max_position_embeddings, + base, + dim, + is_neox_style, + dtype, + device, + batch_size, + seq_len, + num_q_heads, + num_kv_heads, + ) + + def test_apply_rotary_pos_emb(self): + num_tokens = 1024 + num_heads = 8 + head_size = 72 + qkv = torch.randn(num_tokens, num_heads * head_size * 3).to(torch.bfloat16) + query, key, _ = qkv.split( + [num_heads * head_size, num_heads * head_size, num_heads * head_size], + dim=-1, + ) + query = query.view(num_tokens, num_heads, head_size) + key = key.view(num_tokens, num_heads, head_size) + for sincos_dtype in [torch.float32, torch.bfloat16]: + cos = torch.rand(num_tokens, head_size).to(sincos_dtype) + sin = torch.rand(num_tokens, head_size).to(sincos_dtype) + q_out_ref, k_out_ref = apply_rotary_pos_emb_native(query, key, cos, sin) + q_out_sgl, k_out_sgl = torch.ops.sgl_kernel.apply_rotary_pos_emb_cpu( + query, key, cos, sin + ) + torch.testing.assert_close(q_out_ref, q_out_sgl, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(k_out_ref, k_out_sgl, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_server_args_backend.py b/test/registered/cpu/test_server_args_backend.py new file mode 100644 index 000000000000..22448a7f471d --- /dev/null +++ b/test/registered/cpu/test_server_args_backend.py @@ -0,0 +1,38 @@ +import unittest +from unittest.mock import patch + +from sglang.srt.server_args import ServerArgs +from sglang.test.ci.ci_register import register_cpu_ci + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + + +class TestServerArgsCPUBackend(unittest.TestCase): + def _make_server_args(self, attention_backend=None): + server_args = 
ServerArgs.__new__(ServerArgs) + server_args.device = "cpu" + server_args.attention_backend = attention_backend + server_args.sampling_backend = None + return server_args + + @patch("sglang.srt.server_args.is_host_cpu_arm64", return_value=True) + def test_arm_cpu_defaults_to_torch_native(self, _mock_is_arm64): + server_args = self._make_server_args() + + ServerArgs._handle_cpu_backends(server_args) + + self.assertEqual(server_args.attention_backend, "torch_native") + self.assertEqual(server_args.sampling_backend, "pytorch") + + @patch("sglang.srt.server_args.is_host_cpu_arm64", return_value=False) + def test_x86_cpu_defaults_to_intel_amx(self, _mock_is_arm64): + server_args = self._make_server_args() + + ServerArgs._handle_cpu_backends(server_args) + + self.assertEqual(server_args.attention_backend, "intel_amx") + self.assertEqual(server_args.sampling_backend, "pytorch") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_shared_expert.py b/test/registered/cpu/test_shared_expert.py new file mode 100644 index 000000000000..18b3cef78acc --- /dev/null +++ b/test/registered/cpu/test_shared_expert.py @@ -0,0 +1,235 @@ +import itertools +import math +import unittest + +import torch +from utils import ( + BLOCK_K, + BLOCK_N, + factor_for_scale, + fp8_max, + fp8_min, + per_token_quant_int8, + precision, + scaled_weight, + torch_naive_moe, + torch_w8a8_per_column_moe, +) + +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +class TestSharedExpert(CustomTestCase): + M = [2, 121] + N = [32, 32 * 4] + K = [32, 32 * 2] + routed_scaling_factor = [16] + apply_scaling_factor = [True, False] + + M_fp8 = [2, 12] + N_fp8 = [512] + K_fp8 = [256] + + def _bf16_shared_expert(self, m, n, k, routed_scaling_factor, apply_scaling_factor): + dtype = torch.bfloat16 + + hidden_states = torch.randn(m, k, dtype=dtype) 
/ k + w1 = torch.randn(2 * n, k, dtype=dtype) + w2 = torch.randn(k, n, dtype=dtype) + fused_output = ( + torch.randn(m, k, dtype=dtype) / k if apply_scaling_factor else None + ) + routed_scaling_factor = routed_scaling_factor if apply_scaling_factor else None + + # fused moe mutates content in hs + hidden_states2 = hidden_states.clone() + + # bfloat16 + ref = torch_naive_moe( + hidden_states, + w1, + w2, + fused_output, + routed_scaling_factor, + output_dtype=dtype, + ) + out = torch.ops.sgl_kernel.shared_expert_cpu( + hidden_states2, + w1, + w2, + fused_output, + routed_scaling_factor, + True, + False, + False, + None, + None, + None, + False, + ) + + atol = rtol = precision[ref.dtype] + torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) + + def test_bf16_shared_expert(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.routed_scaling_factor, + self.apply_scaling_factor, + ): + with self.subTest( + m=params[0], + n=params[1], + k=params[2], + routed_scaling_factor=params[3], + apply_scaling_factor=params[4], + ): + self._bf16_shared_expert(*params) + + def _int8_shared_expert(self, m, n, k, routed_scaling_factor, apply_scaling_factor): + dtype = torch.bfloat16 + + hidden_states = torch.randn(m, k, dtype=dtype) / k + w1 = torch.randn(2 * n, k, dtype=dtype) + w2 = torch.randn(k, n, dtype=dtype) + fused_output = ( + torch.randn(m, k, dtype=dtype) / k if apply_scaling_factor else None + ) + routed_scaling_factor = routed_scaling_factor if apply_scaling_factor else None + + # fused moe mutates content in hs + hidden_states2 = hidden_states.clone() + + w1_q, w1_s = per_token_quant_int8(w1) + w2_q, w2_s = per_token_quant_int8(w2) + ref = torch_w8a8_per_column_moe( + hidden_states, + w1_q, + w2_q, + w1_s, + w2_s, + fused_output, + routed_scaling_factor, + ) + out = torch.ops.sgl_kernel.shared_expert_cpu( + hidden_states2, + w1_q, + w2_q, + fused_output, + routed_scaling_factor, + True, + True, + False, + w1_s, + w2_s, + None, + False, + 
) + + atol = rtol = precision[ref.dtype] + torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) + + def test_int8_shared_expert(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.routed_scaling_factor, + self.apply_scaling_factor, + ): + with self.subTest( + m=params[0], + n=params[1], + k=params[2], + routed_scaling_factor=params[3], + apply_scaling_factor=params[4], + ): + self._int8_shared_expert(*params) + + def _fp8_shared_expert(self, m, n, k, routed_scaling_factor, apply_scaling_factor): + dtype = torch.bfloat16 + + hidden_states = torch.randn(m, k, dtype=dtype) / math.sqrt(k) + + w1_fp32 = torch.randn(1, 2 * n, k) + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = torch.randn(1, k, n) + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w1s = torch.randn(1, 2 * n // BLOCK_N, k // BLOCK_K) * factor_for_scale + w2s = torch.randn(1, k // BLOCK_N, n // BLOCK_K) * factor_for_scale + + w1_scaled = scaled_weight(w1, w1s).view(2 * n, k) + w2_scaled = scaled_weight(w2, w2s).view(k, n) + + # change back to 2D + w1, w2 = w1.squeeze(0), w2.squeeze(0) + w1s, w2s = w1s.squeeze(0), w2s.squeeze(0) + w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0) + + fused_output = ( + torch.randn(m, k, dtype=dtype) / math.sqrt(k) + if apply_scaling_factor + else None + ) + routed_scaling_factor = routed_scaling_factor if apply_scaling_factor else None + hidden_states2 = hidden_states.clone() + + # ref with bfloat16 + ref = torch_naive_moe( + hidden_states, + w1_scaled, + w2_scaled, + fused_output, + routed_scaling_factor, + output_dtype=dtype, + ) + + w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K] + w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N] + out = torch.ops.sgl_kernel.shared_expert_cpu( + hidden_states2, + w1, + w2, + fused_output, + routed_scaling_factor, + True, + False, + True, + w1s, + w2s, + [BLOCK_N, BLOCK_K], + True, 
+ ) + + atol = rtol = precision[ref.dtype] + torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) + + def test_fp8_shared_expert(self): + for params in itertools.product( + self.M_fp8, + self.N_fp8, + self.K_fp8, + self.routed_scaling_factor, + self.apply_scaling_factor, + ): + with self.subTest( + m=params[0], + n=params[1], + k=params[2], + routed_scaling_factor=params[3], + apply_scaling_factor=params[4], + ): + self._fp8_shared_expert(*params) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/test_topk.py b/test/registered/cpu/test_topk.py new file mode 100644 index 000000000000..2d6d2bbbeb0f --- /dev/null +++ b/test/registered/cpu/test_topk.py @@ -0,0 +1,223 @@ +import unittest + +import torch + +from sglang.srt.layers.moe.topk import ( + biased_grouped_topk_impl as native_biased_grouped_topk, +) +from sglang.srt.layers.moe.topk import fused_topk_torch_native as native_fused_topk +from sglang.srt.layers.moe.topk import grouped_topk_gpu as native_grouped_topk +from sglang.srt.models.llama4 import Llama4MoE +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-b-test-cpu") + +torch.manual_seed(1234) + + +# This is used by the Deepseek-V2 model +class TestGroupedTopK(CustomTestCase): + def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype): + torch.manual_seed(1234) + + # scale gating_output up by 2 * M, otherwise bfloat16 values fall into the same value after truncation + hidden_states = torch.randn(M, 100, dtype=dtype) + gating_output = torch.randn(M, E, dtype=dtype) * 2 * M + + ref_topk_weights, ref_topk_ids = native_grouped_topk( + hidden_states.float(), + gating_output.float(), + topk, + renormalize, + G, + topk_group, + ) + + # fused version + topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu( + hidden_states, + gating_output, + topk, + renormalize, + G, + topk_group, + 0, + None, + None, + ) + + res =
torch.zeros(M, E, dtype=torch.float) + ref = torch.zeros(M, E, dtype=torch.float) + res.scatter_(1, topk_ids.long(), topk_weights) + ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) + torch.testing.assert_close(res, ref) + + def test_grouped_topk(self): + for renormalize in [True, False]: + self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16) + self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16) + self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16) + self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16) + self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16) + self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16) + self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16) + + +# DeepSeek V2/V3/R1 use biased_grouped_topk +class TestBiasedGroupedTopK(CustomTestCase): + def _run_single_test( + self, + M, + E, + G, + topk, + topk_group, + renormalize, + gating_dtype, + bias_dtype, + routed_scaling_factor, + ): + torch.manual_seed(1024) + + # scale gating_output up by 2 * M, otherwise bfloat16 values fall into the same value after truncation + hidden_states = torch.randn(M, 100, dtype=torch.bfloat16) + gating_output = torch.randn(M, E, dtype=gating_dtype) * 2 * M + correction_bias = torch.randn(E, dtype=bias_dtype) + + ref_topk_weights, ref_topk_ids = native_biased_grouped_topk( + hidden_states.float(), + gating_output.float(), + correction_bias.float(), + topk, + renormalize, + G, + topk_group, + ) + ref_topk_weights = ( + ref_topk_weights * routed_scaling_factor + if routed_scaling_factor is not None + else ref_topk_weights + ) + # fused version + topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu( + hidden_states, + gating_output, + correction_bias, + topk, + renormalize, + G, + topk_group, + 0, + routed_scaling_factor, + None, + ) + + res = torch.zeros(M, E, dtype=torch.float) + ref = torch.zeros(M, E, dtype=torch.float) + res.scatter_(1,
topk_ids.long(), topk_weights) + ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) + torch.testing.assert_close(res, ref) + + def test_biased_grouped_topk(self): + for renormalize in [False]: + for bias_dtype in [torch.float32, torch.bfloat16]: + for gating_dtype in [torch.float32, torch.bfloat16]: + for routed_scaling_factor in [None, 1.125]: + for E_num in [128, 192, 256, 384]: + self._run_single_test( + 34, + E_num, + 8, + 8, + 2, + renormalize, + gating_dtype, + bias_dtype, + routed_scaling_factor, + ) + + +class TestTopK(CustomTestCase): + def _run_single_test(self, M, E, topk, renormalize, dtype): + torch.manual_seed(1998) + + # scale gating_output up by 2 * M, otherwise bfloat16 values fall into the same value after truncation + hidden_states = torch.randn(M, 100, dtype=dtype) + gating_output = torch.randn(M, E, dtype=dtype) * 2 * M + + ref_topk_weights, ref_topk_ids = native_fused_topk( + hidden_states.float(), + gating_output.float(), + topk, + renormalize, + ) + + # fused version + topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu( + hidden_states, gating_output, topk, renormalize + ) + + res = torch.zeros(M, E, dtype=torch.float) + ref = torch.zeros(M, E, dtype=torch.float) + res.scatter_(1, topk_ids.long(), topk_weights) + ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) + torch.testing.assert_close(res, ref) + + def test_topk(self): + for renormalize in [True, False]: + self._run_single_test(123, 8, 2, renormalize, torch.bfloat16) + self._run_single_test(123, 16, 3, renormalize, torch.bfloat16) + self._run_single_test(123, 32, 3, renormalize, torch.bfloat16) + self._run_single_test(123, 32, 3, renormalize, torch.bfloat16) + self._run_single_test(123, 64, 6, renormalize, torch.bfloat16) + self._run_single_test(123, 256, 4, renormalize, torch.bfloat16) + self._run_single_test(123, 160, 6, renormalize, torch.bfloat16) + + +class TestCustomTopK(CustomTestCase): + def _run_single_test( + self, M, E, topk, renormalize, dtype, native_custom_f,
fused_custom_f + ): + torch.manual_seed(16) + + # scale gating_output up by 2 * M, otherwise bfloat16 values fall into the same value after truncation + hidden_states = torch.randn(M, 100, dtype=dtype) + gating_output = torch.randn(M, E, dtype=dtype) * 2 * M + + ref_topk_weights, ref_topk_ids = native_custom_f( + hidden_states.float(), + gating_output.float(), + topk, + renormalize, + ) + + # fused version + topk_weights, topk_ids = fused_custom_f( + hidden_states, gating_output, topk, renormalize + ) + + res = torch.zeros(M, E, dtype=torch.float) + ref = torch.zeros(M, E, dtype=torch.float) + res.scatter_(1, topk_ids.long(), topk_weights) + ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) + torch.testing.assert_close(res, ref) + + def test_custom_topk(self): + test_custom_functions = [ + (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu) + ] + for native_custom_f, fused_custom_f in test_custom_functions: + self._run_single_test( + 123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f + ) + self._run_single_test( + 123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f + ) + self._run_single_test( + 123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/cpu/utils.py b/test/registered/cpu/utils.py new file mode 100644 index 000000000000..57e7e74c103c --- /dev/null +++ b/test/registered/cpu/utils.py @@ -0,0 +1,440 @@ +import itertools +import math + +import torch +import torch.nn.functional as F + +precision = { + torch.bfloat16: 1e-2, + torch.float16: 1e-3, + torch.float32: 1e-5, +} + + +BLOCK_N, BLOCK_K = 64, 128 +factor_for_scale = 1e-3 +fp8_max, fp8_min = 400, -400 + + +def parametrize(**params): + def decorator(func): + def wrapper(self): + for combo in itertools.product(*params.values()): + kwargs = dict(zip(params.keys(), combo)) + with self.subTest(**kwargs): + func(self, **kwargs) + + return wrapper + + return decorator + +
+def SiluAndMul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor: + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate=approximate) * x[..., d:] + + +def per_token_quant_int8(x): + x = x.float() + absmax = x.abs().max(dim=-1).values + absmax = absmax.clamp_min(1e-10).unsqueeze(-1) + scale_x = absmax / 127 + x_q = x.mul(127 / absmax) + x_q = torch.round(x_q).to(torch.int8) + + return x_q, scale_x + + +def convert_weight(weight, scale_block_size, A_dtype): + N, K = weight.size() + fp8_max = 448.0 + scale_block_size_N, scale_block_size_K = scale_block_size # (128, 128) + + pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N + pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K + + if pad_N > 0 or pad_K > 0: + weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N)) + + weight_blocks = weight.view( + math.ceil(N / scale_block_size_N), + scale_block_size_N, + math.ceil(K / scale_block_size_K), + scale_block_size_K, + ) # (8, 128, 8, 128) + weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous() # (8, 8, 128, 128) + + # Step 2: compute per-block max abs values → scale + abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True) # (8, 8, 1, 1) + scales = abs_max / fp8_max + scales = torch.where( + scales == 0, torch.ones_like(scales), scales + ) # avoid division by zero + + q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn) + q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous() + + if pad_N > 0 or pad_K > 0: + q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K) + q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous() + else: + q_fp8_reshape = q_fp8_reshape.view(N, K) + + dq_weight = q_fp8.float() * scales + dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous() # (8, 128, 8, 128) + + if pad_N > 0 or pad_K > 0: + w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype) 
+ w_dq = w_dq[:N, :K].contiguous() + else: + w_dq = dq_weight.view(N, K).to(A_dtype) + + scales = scales.view( + math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K) + ) + + return q_fp8_reshape, scales, w_dq + + +def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16): + """Matrix multiplication function that supports per-token input quantization and per-column weight quantization""" + A = A.to(torch.float32) + B = B.to(torch.float32) + + assert A.shape[-1] == B.shape[-1], "Dimension mismatch" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" + + # Reshape input + M = A.numel() // A.shape[-1] + B = B.t() # Transpose weight matrix + N, K = B.shape + origin_C_shape = A.shape[:-1] + (K,) + A = A.reshape(M, N) + + # As is per-token [M, 1], Bs is per-column [1, K] + C = torch.matmul(A, B) # [M, K] + C = As * C * Bs.view(1, -1) # Broadcast per-column scale + + if bias is not None: + C.add_(bias.view(1, -1)) + + return C.reshape(origin_C_shape).to(output_dtype) + + +def torch_naive_moe(a, w1, w2, b, routed_scaling_factor, output_dtype=torch.bfloat16): + + a = a.to(torch.float32) + w1 = w1.to(torch.float32) + w2 = w2.to(torch.float32) + b = b.to(torch.float32) if b is not None else None + + ic1 = torch.matmul(a, w1.transpose(0, 1)) + ic2 = SiluAndMul(ic1) + ic3 = torch.matmul(ic2, w2.transpose(0, 1)) + + out = ic3 if b is None else ic3 + b * routed_scaling_factor + + return out.to(output_dtype) + + +def torch_w8a8_per_column_moe( + a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor, output_dtype=torch.bfloat16 +): + + a = a.to(torch.float32) + b = b.to(torch.float32) if b is not None else None + + # Perform per-token quantization + a_q, a_s = per_token_quant_int8(a) + + ic1 = native_w8a8_per_token_matmul( + a_q, w1_q, a_s, w1_s, bias=None, output_dtype=torch.float32 + ) + ic2 = SiluAndMul(ic1) + + a1_q, a1_s = per_token_quant_int8(ic2) + ic3 = native_w8a8_per_token_matmul( + a1_q, w2_q, a1_s, w2_s, 
bias=None, output_dtype=torch.float32 + ) + + out = ic3 if b is None else ic3 + b * routed_scaling_factor + + return out.to(output_dtype) + + +def scaled_weight(weight, scales): + E, N, K = weight.shape + pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N + pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K + + if pad_N > 0 or pad_K > 0: + weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N)) + + weight_block = ( + weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K) + .permute(0, 1, 3, 2, 4) + .float() + .contiguous() + ) + + weight_scaled = ( + ( + weight_block + * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1) + ) + .permute(0, 1, 3, 2, 4) + .contiguous() + ) + if pad_N > 0 or pad_K > 0: + weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K) + weight_scaled = weight_scaled[..., :N, :K].contiguous() + else: + weight_scaled = weight_scaled.view(E, N, K) + return weight_scaled + + +def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + + if renormalize: + topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) + + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( + 0, 1 + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk): + """This function performs fused moe with per-column int8 quantization using native torch.""" + + B, D = a.shape + # Perform per-token quantization + a_q, a_s = per_token_quant_int8(a) + # Repeat tokens to match topk 
+ a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + # Also repeat the scale + a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] + + out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device) + + # Calculate routing + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + # Process each expert + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + # First MLP layer: note that a_s is now per-token + inter_out = native_w8a8_per_token_matmul( + a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + bias=None, + output_dtype=torch.float32, + ) + # Activation function + act_out = SiluAndMul(inter_out) + # Quantize activation output with per-token + act_out_q, act_out_s = per_token_quant_int8(act_out) + # Second MLP layer + out[mask] = native_w8a8_per_token_matmul( + act_out_q, + w2[i], + act_out_s, + w2_s[i], + bias=None, + output_dtype=torch.float32, + ) + # Apply routing weights and sum + return ( + (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)) + .sum(dim=1) + .to(a.dtype) + ) + + +def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D).float() + out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device) + + # Calculate routing + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + ic0 = torch.matmul(a[mask], w1[i].transpose(0, 1)) + ic1 = SiluAndMul(ic0) + out[mask] = torch.matmul(ic1, w2[i].transpose(0, 1)) + + return ( + (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)) + .sum(dim=1) + .to(a.dtype) + ) + + +def make_non_contiguous(x: torch.Tensor) -> torch.Tensor: + """ + Make a tensor non-contiguous by slicing it via last dimension. 
+ """ + last_dim = x.shape[-1] + return x[..., : last_dim // 2] if x.is_contiguous() else x + + +def awq_reverse_reorder_int_tensor(int_tensor, bits: int): + assert bits == 4 + + int_tensor = int_tensor.T.contiguous() + compress_ratio = 32 // bits + assert int_tensor.shape[-1] % compress_ratio == 0 + + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + order_tensor = torch.tensor( + order_map, dtype=torch.int32, device=int_tensor.device + ).reshape(1, -1) + order_tensor = order_tensor.repeat(int_tensor.shape[1] // compress_ratio, 1) + order_tensor = order_tensor + torch.arange( + 0, + int_tensor.shape[1], + compress_ratio, + dtype=torch.int32, + device=int_tensor.device, + ).reshape(-1, 1) + order_tensor = order_tensor.reshape(-1) + + reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor] + reverse_order_tensor = reverse_order_tensor[order_tensor] + int_tensor = int_tensor[:, reverse_order_tensor] + return int_tensor + + +def unpack_and_dequant_awq( + awq_qweight: torch.Tensor, + awq_qzeros: torch.Tensor, + awq_scales: torch.Tensor, + bits: int, + group_size: int, +): + """ + Args: + awq_qweight (`torch.LongTensor`): + Expected shape: (in_features, out_features // (32 // bits)) + awq_qzeros (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features // (32 // bits)) + awq_scales (`torch.LongTensor`): + Expected shape: (in_features // group_size, out_features) + + Returns: + fp16_weight (`torch.LongTensor`): + With shape (in_features, out_features). + zeros (`torch.LongTensor`): + With shape (in_features // group_size, out_features). 
+ """ + assert bits == 4 + + qzeros = awq_qzeros + qweight = awq_qweight + qweight = qweight.T.contiguous() + + scales = awq_scales + scales = scales.reshape(-1, 1, scales.shape[-1]) + + infeatures = awq_qweight.shape[0] + + wf = torch.tensor( + list(range(0, 32, bits)), dtype=torch.int32, device=qzeros.device + ).unsqueeze(0) + zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0)).to( + torch.int16 if bits == 8 else torch.int8 + ) + + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) + + zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2]) + + weight = torch.bitwise_right_shift( + torch.unsqueeze(qweight, 1), wf.unsqueeze(-1) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(weight, (2**bits) - 1, out=weight) + weight = weight.reshape(-1, group_size, weight.shape[2]) + + weight = weight.view(-1, weight.shape[-1]) + zeros = zeros.view(-1, zeros.shape[-1]) + + zeros = zeros.T.contiguous() + zeros = awq_reverse_reorder_int_tensor(zeros, bits) + weight = awq_reverse_reorder_int_tensor(weight, bits) + + # Dequantize weights. 
+ scales = awq_scales + zeros = zeros.contiguous() + scale_zeros = zeros * scales + + g_idx = torch.tensor( + [i // group_size for i in range(infeatures)], dtype=torch.int32 + ) + scale_mat = scales[g_idx] + scale_zeros_mat = scale_zeros[g_idx].to(torch.bfloat16) + + qdq_weight_T = weight * scale_mat - scale_zeros_mat.to(torch.bfloat16) + + fp16_weight = qdq_weight_T.T + + return fp16_weight, zeros + + +def unpack_4bit_to_32bit_signed(qweight, qzeros): + # Unpack 4-bit values and interpret them as signed integers + unpacked_weights = torch.zeros( + (qweight.shape[0] * 8, qweight.shape[1]), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + unpacked_zeros = torch.zeros( + (qzeros.shape[0], qzeros.shape[1] * 8), + dtype=torch.int8, + device=qzeros.device, + requires_grad=False, + ) + + for row in range(unpacked_weights.shape[0]): + i = row % 8 + unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF + + for col in range(unpacked_zeros.shape[1]): + i = col % 8 + unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF + + return unpacked_weights, unpacked_zeros + 1 + + +def unpack_and_dequant_gptq(qweight, qzeros, scales): + unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros) + group_size = unpacked_qweight.shape[0] // scales.shape[0] + scales = scales.repeat_interleave(group_size, dim=0) + unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0) + unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales + + return unpacked_qweight.T diff --git a/test/run_suite.py b/test/run_suite.py index c09b1fae8fc9..e38546c66f3e 100644 --- a/test/run_suite.py +++ b/test/run_suite.py @@ -23,7 +23,7 @@ # Per-commit test suites (run on every PR) PER_COMMIT_SUITES = { - HWBackend.CPU: ["stage-a-test-cpu"], + HWBackend.CPU: ["stage-a-test-cpu", "stage-b-test-cpu"], HWBackend.AMD: [ "stage-a-test-1-gpu-small-amd", "stage-b-test-1-gpu-small-amd", @@ -238,7 +238,9 @@ def run_a_suite(args): for f 
in glob.glob( os.path.join(script_dir, "registered", "**", "*.py"), recursive=True ) - if not f.endswith("/conftest.py") and not f.endswith("/__init__.py") + if not f.endswith("/conftest.py") + and not f.endswith("/__init__.py") + and not f.endswith("/cpu/utils.py") ] # JIT kernel tests and benchmarks (live alongside kernel source)