From fbeed5262dfb65652502cc1148a9617b5ffbda84 Mon Sep 17 00:00:00 2001 From: Ma Mingfei Date: Tue, 12 May 2026 13:32:10 +0800 Subject: [PATCH] Revert "Migrate Intel CPU cases to the test/registered. (#22670)" This reverts commit ecf5d844f5bfd8d243043da366fb8b58f6911ef3. --- .github/workflows/pr-test-xeon.yml | 2 +- .pre-commit-config.yaml | 1 - docs_new/index.mdx | 60 +-- scripts/ci/check_registered_tests.py | 4 +- test/registered/cpu/test_activation.py | 59 --- test/registered/cpu/test_binding.py | 30 -- test/registered/cpu/test_bmm.py | 98 ---- test/registered/cpu/test_causal_conv1d.py | 330 ------------- test/registered/cpu/test_cpu_graph.py | 91 ---- test/registered/cpu/test_decode.py | 172 ------- test/registered/cpu/test_extend.py | 225 --------- test/registered/cpu/test_flash_attn.py | 245 ---------- test/registered/cpu/test_gemm.py | 335 ------------- .../cpu/test_intel_amx_attention_backend_a.py | 76 --- .../cpu/test_intel_amx_attention_backend_b.py | 38 -- .../cpu/test_intel_amx_attention_backend_c.py | 56 --- test/registered/cpu/test_mamba.py | 397 ---------------- test/registered/cpu/test_mla.py | 158 ------- test/registered/cpu/test_moe.py | 355 -------------- test/registered/cpu/test_norm.py | 435 ----------------- .../registered/cpu/test_qkv_proj_with_rope.py | 443 ------------------ test/registered/cpu/test_qwen3.py | 152 ------ test/registered/cpu/test_rope.py | 285 ----------- .../cpu/test_server_args_backend.py | 38 -- test/registered/cpu/test_shared_expert.py | 235 ---------- test/registered/cpu/test_topk.py | 223 --------- test/registered/cpu/utils.py | 440 ----------------- test/run_suite.py | 6 +- 28 files changed, 35 insertions(+), 4954 deletions(-) delete mode 100644 test/registered/cpu/test_activation.py delete mode 100644 test/registered/cpu/test_binding.py delete mode 100644 test/registered/cpu/test_bmm.py delete mode 100644 test/registered/cpu/test_causal_conv1d.py delete mode 100644 test/registered/cpu/test_cpu_graph.py delete mode 100644 test/registered/cpu/test_decode.py delete mode 100644 test/registered/cpu/test_extend.py delete mode 100644 test/registered/cpu/test_flash_attn.py delete mode 100644 test/registered/cpu/test_gemm.py delete mode 100644 test/registered/cpu/test_intel_amx_attention_backend_a.py delete mode 100644 test/registered/cpu/test_intel_amx_attention_backend_b.py delete mode 100644 test/registered/cpu/test_intel_amx_attention_backend_c.py delete mode 100644 test/registered/cpu/test_mamba.py delete mode 100644 test/registered/cpu/test_mla.py delete mode 100644 test/registered/cpu/test_moe.py delete mode 100644 test/registered/cpu/test_norm.py delete mode 100644 test/registered/cpu/test_qkv_proj_with_rope.py delete mode 100644 test/registered/cpu/test_qwen3.py delete mode 100644 test/registered/cpu/test_rope.py delete mode 100644 test/registered/cpu/test_server_args_backend.py delete mode 100644 test/registered/cpu/test_shared_expert.py delete mode 100644 test/registered/cpu/test_topk.py delete mode 100644 test/registered/cpu/utils.py diff --git a/.github/workflows/pr-test-xeon.yml b/.github/workflows/pr-test-xeon.yml index abaa9401d96d..0fb4721ba173 100644 --- a/.github/workflows/pr-test-xeon.yml +++ b/.github/workflows/pr-test-xeon.yml @@ -115,7 +115,7 @@ jobs: timeout-minutes: 36 run: | docker exec -w /sglang-checkout/ ci_sglang_xeon \ - bash -c "source /opt/.venv/bin/activate && cd ./test && python3 run_suite.py --hw cpu --suite stage-b-test-cpu" + bash -c "source /opt/.venv/bin/activate && cd ./test/srt && python3 run_suite.py --suite per-commit-cpu --timeout-per-file 1500" - name: Change permission timeout-minutes: 2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7fea2c91d535..8118e91c26cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -92,7 +92,6 @@ repos: entry: python3 scripts/ci/check_registered_tests.py language: system files: ^test/registered/.*\.py$ - exclude: ^test/registered/.*/utils\.py$ pass_filenames: false - id: check-no-docs-changes name: reject changes under legacy docs/ diff --git a/docs_new/index.mdx b/docs_new/index.mdx index 908d517386da..6a5b1ed19ddb 100644 --- a/docs_new/index.mdx +++ b/docs_new/index.mdx @@ -83,7 +83,7 @@ It is designed to deliver low-latency and high-throughput inference across a wid }} > Updating 1T parameters in seconds \u2014 P2P weight transfer in Large Scale Distributed RL - {"Updating 1T parameters in seconds \u2014 P2P weight transfer in Large Scale Distributed RL"} + {"DeepSeek-V4 on Day 0: From Fast Inference to Verified RL with SGLang and Miles"}

- {"April 29, 2026"} + {"April 25, 2026"}

DeepSeek-V4 on Day 0: From Fast Inference to Verified RL with SGLang and Miles - {"DeepSeek-V4 on Day 0: From Fast Inference to Verified RL with SGLang and Miles"} + {"HiSparse: Turbocharging Sparse Attention with Hierarchical Memory"}

- {"April 25, 2026"} + {"April 10, 2026"}

HiSparse: Turbocharging Sparse Attention with Hierarchical Memory - {"HiSparse: Turbocharging Sparse Attention with Hierarchical Memory"} + {"Highlights of SGLang at NVIDIA GTC 2026"}

- {"April 10, 2026"} + {"March 31, 2026"}

Highlights of SGLang at NVIDIA GTC 2026 - {"Highlights of SGLang at NVIDIA GTC 2026"} + {"Elastic EP in SGLang: Achieving Partial Failure Tolerance for DeepSeek MoE Deployments"}

- {"March 31, 2026"} + {"March 25, 2026"}

Elastic EP in SGLang: Achieving Partial Failure Tolerance for DeepSeek MoE Deployments - {"Elastic EP in SGLang: Achieving Partial Failure Tolerance for DeepSeek MoE Deployments"} + {"ROCm Support for Miles: Large-Scale RL Post-Training on AMD Instinct\u2122 GPUs"}

- {"March 25, 2026"} + {"March 17, 2026"}

ROCm Support for Miles: Large-Scale RL Post-Training on AMD Instinct\u2122 GPUs - {"ROCm Support for Miles: Large-Scale RL Post-Training on AMD Instinct\u2122 GPUs"} + {"SGLang Adds Day-0 Support for NVIDIA Nemotron 3 Super for building High-Efficiency Multi-Agent Systems"}

- {"March 17, 2026"} + {"March 11, 2026"}

diff --git a/scripts/ci/check_registered_tests.py b/scripts/ci/check_registered_tests.py index 66def25ff4d2..3a9e9b87b242 100755 --- a/scripts/ci/check_registered_tests.py +++ b/scripts/ci/check_registered_tests.py @@ -22,11 +22,11 @@ def main() -> int: ci_register = importlib.util.module_from_spec(spec) spec.loader.exec_module(ci_register) - # Same filter as run_suite.py: skip conftest.py, __init__.py, and utils.py + # Same filter as run_suite.py: skip conftest.py and __init__.py files = sorted( f for f in glob.glob("test/registered/**/*.py", recursive=True) - if os.path.basename(f) not in ("conftest.py", "__init__.py", "utils.py") + if os.path.basename(f) not in ("conftest.py", "__init__.py") ) if not files: return 0 diff --git a/test/registered/cpu/test_activation.py b/test/registered/cpu/test_activation.py deleted file mode 100644 index fe020c8872f4..000000000000 --- a/test/registered/cpu/test_activation.py +++ /dev/null @@ -1,59 +0,0 @@ -import itertools -import unittest - -import torch -from utils import GeluAndMul, SiluAndMul, precision - -from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class TestActivation(CustomTestCase): - M = [128, 129, 257] - N = [22016, 22018] - dtype = [torch.float16, torch.bfloat16] - - def _silu_and_mul_test(self, m, n, dtype): - set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) - - x = torch.randn([m, n], dtype=dtype) - - out = torch.ops.sgl_kernel.silu_and_mul_cpu(x) - ref_out = SiluAndMul(x) - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - def _gelu_and_mul_test(self, m, n, dtype): - x = torch.randn([m, n], dtype=dtype) - - out = torch.ops.sgl_kernel.gelu_and_mul_cpu(x) - ref_out = GeluAndMul(x, approximate="none") - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - def _gelu_tanh_and_mul_test(self, m, n, dtype): - x = torch.randn([m, n], dtype=dtype) - - out = torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x) - ref_out = GeluAndMul(x, approximate="tanh") - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - def test_activation(self): - for params in itertools.product(self.M, self.N, self.dtype): - with self.subTest(m=params[0], n=params[1], dtype=params[2]): - self._silu_and_mul_test(*params) - self._gelu_and_mul_test(*params) - self._gelu_tanh_and_mul_test(*params) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_binding.py b/test/registered/cpu/test_binding.py deleted file mode 100644 index f623045a74df..000000000000 --- a/test/registered/cpu/test_binding.py +++ /dev/null @@ -1,30 +0,0 @@ -import re -import unittest - -import torch - -kernel = torch.ops.sgl_kernel - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - - -class TestBinding(CustomTestCase): - def test_binding(self): - start_id = 1 - n_cpu = 6 - - expected_cores = list(map(str, range(start_id, start_id + n_cpu))) - cpu_ids = ",".join(expected_cores) - output = kernel.init_cpu_threads_env(cpu_ids) - - bindings = re.findall(r"OMP tid: \d+, core (\d+)", output) - self.assertEqual(len(bindings), n_cpu) - - self.assertEqual(bindings, expected_cores) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_bmm.py b/test/registered/cpu/test_bmm.py deleted file mode 100644 index 93257b7182e7..000000000000 --- a/test/registered/cpu/test_bmm.py +++ /dev/null @@ -1,98 +0,0 @@ -import itertools -import unittest - -# TODO: use interface in cpu.py -import torch -import torch.nn as nn -from utils import precision - -from sglang.srt.layers.quantization.fp8_utils import input_to_float8 -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class Mod(nn.Module): - def __init__(self, input_channel, output_channel, has_bias): - super(Mod, self).__init__() - self.linear = torch.nn.Linear(input_channel, output_channel, has_bias) - - def forward(self, x): - return self.linear(x) - - -class TestBmm(CustomTestCase): - M = [1, 2, 11, 111] - N = [128 + 32, 512] - K = [512 + 32, 128 + 32] - B = [1, 16, 17] - chunk = [True, False] - - def _get_bmm_inputs(self, B, M, N, K, chunk, dtype): - if chunk: - mat1 = ( - torch.randn(M, B, K + 64, dtype=dtype).narrow(2, 0, K).transpose_(0, 1) - ) - mat2 = torch.randn(B, N, K, dtype=dtype).transpose_(1, 2) - mat3 = ( - torch.randn(M, B, N + 64, dtype=dtype).narrow(2, 0, N).transpose_(0, 1) - ) - else: - mat1 = torch.randn(M, B, K, dtype=dtype).transpose_(0, 1) - mat2 = torch.randn(B, N, K, dtype=dtype).transpose_(1, 2) - mat3 = torch.randn(M, B, N, dtype=dtype).transpose_(0, 1) - return mat1, mat2, mat3 - - def _bf16_bmm(self, B, M, N, K, chunk, dtype=torch.bfloat16): - mat1, mat2, mat3 = self._get_bmm_inputs(B, M, N, K, chunk, dtype) - ref = torch.bmm(mat1, mat2) - mat2_t = mat2.transpose_(1, 2) - mat3.zero_() - torch.ops.sgl_kernel.bmm_cpu(mat3, mat1, mat2, False, None) - atol = rtol = precision[ref.dtype] - torch.testing.assert_close(ref, mat3, atol=atol, rtol=rtol) - - packed_B = torch.ops.sgl_kernel.convert_weight_packed(mat2_t) - mat3.zero_() - torch.ops.sgl_kernel.bmm_cpu(mat3, mat1, packed_B, True, None) - torch.testing.assert_close(ref, mat3, atol=atol, rtol=rtol) - - def _fp8_bmm(self, B, M, N, K, chunk, dtype=torch.bfloat16): - mat1, mat2, mat3 = self._get_bmm_inputs(B, M, N, K, chunk, dtype) - mat2_q, mat2_s = input_to_float8(mat2) - ref = torch.bmm(mat1, mat2_q.to(torch.bfloat16)) * mat2_s - mat2_q_t = mat2_q.transpose_(1, 2).contiguous() - mat3.zero_() - atol = rtol = precision[ref.dtype] - torch.ops.sgl_kernel.bmm_cpu(mat3, mat1, mat2_q_t, False, mat2_s) - torch.testing.assert_close(ref, mat3, atol=atol, rtol=rtol) - - packed_B_q = torch.ops.sgl_kernel.convert_weight_packed(mat2_q_t) - mat3.zero_() - torch.ops.sgl_kernel.bmm_cpu(mat3, mat1, packed_B_q, True, mat2_s) - torch.testing.assert_close(ref, mat3, atol=atol, rtol=rtol) - - def test_bmm(self): - for params in itertools.product( - self.B, - self.M, - self.N, - self.K, - self.chunk, - ): - with self.subTest( - B=params[0], - M=params[1], - N=params[2], - K=params[3], - chunk=params[4], - ): - self._bf16_bmm(*params) - self._fp8_bmm(*params) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_causal_conv1d.py b/test/registered/cpu/test_causal_conv1d.py deleted file mode 100644 index f588eafa2bdc..000000000000 --- a/test/registered/cpu/test_causal_conv1d.py +++ /dev/null @@ -1,330 +0,0 @@ -import unittest -from typing import Optional - -import sgl_kernel # noqa: F401 -import torch -import torch.nn.functional as F -from utils import parametrize, precision - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -causal_conv1d_weight_pack = torch.ops.sgl_kernel.causal_conv1d_weight_pack -causal_conv1d_fwd = torch.ops.sgl_kernel.causal_conv1d_fwd_cpu -causal_conv1d_update = torch.ops.sgl_kernel.causal_conv1d_update_cpu - - -torch.manual_seed(1234) - -PAD_SLOT_ID = -1 - - -def causal_conv1d_ref( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - initial_states: Optional[torch.Tensor] = None, - return_final_states: bool = False, - final_states_out: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return (out, None) if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update_ref( - x, conv_state, weight, bias=None, activation=None, cache_seqlens=None -): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the - conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - - x_new = torch.cat([conv_state, x], dim=-1) - conv_state.copy_(x_new[:, :, -state_len:]) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[ - :, :, -seqlen: - ] - - out = out.squeeze(-1) - return out if activation is None else F.silu(out) - - -class TestCausalConv1d(CustomTestCase): - activation = "silu" - - @parametrize( - batch=[1, 1024], - dim=[96, 512], - seqlen=[2, 36], - width=[4], - has_bias=[True, False], - has_initial_state=[True, False], - ) - def test_causal_conv1d( - self, - batch, - dim, - seqlen, - width, - has_bias, - has_initial_state, - dtype=torch.bfloat16, - prepack=True, - ): - x = torch.randn(batch, seqlen, dim).to(dtype).transpose_(-1, -2) - weight = torch.randn(dim, width).to(dtype) - bias = torch.randn(dim).to(dtype) if has_bias else None - - if has_initial_state: - initial_states = torch.randn(batch, dim, width - 1, dtype=dtype) - has_initial_state_tensor = torch.ones(batch, dtype=torch.bool) - else: - initial_states = None - has_initial_state_tensor = None - - packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight - - out_ref, final_states_ref = causal_conv1d_ref( - x, - weight, - bias, - initial_states, - return_final_states=has_initial_state, - activation=self.activation, - ) - - out = causal_conv1d_fwd( - x, - packed_weight, - bias, - initial_states, - None, - None, - has_initial_state_tensor, - self.activation in ["silu"], - PAD_SLOT_ID, - prepack, - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) - torch.testing.assert_close( - final_states_ref, initial_states, atol=atol, rtol=rtol - ) - - @parametrize( - batch=[11], - dim=[96], - max_seqlen=[66], - width=[4], - ) - def test_causal_conv1d_varlen( - self, - batch, - dim, - max_seqlen, - width, - has_bias=False, - dtype=torch.bfloat16, - prepack=False, - ): - total_entries = batch + 3 - - seqlens = torch.randint(1, max_seqlen, (batch + 1,)) - seqlens[0] = 0 - # 1 or 2 must test - seqlens[-2] = 2 - - query_start_loc = torch.cumsum(seqlens, dim=0).to(torch.int32) - - seqlen = query_start_loc[-1].item() - x = torch.randn(seqlen, dim, dtype=dtype).transpose_(-1, -2) - weight = torch.randn(dim, width, dtype=dtype) - bias = torch.randn(dim, dtype=dtype) if has_bias else None - - final_states = torch.randn(total_entries, dim, width - 1, dtype=dtype) - final_states_ref = final_states.clone() - - has_initial_states = torch.randint(0, 2, (batch,), dtype=torch.bool).fill_( - False - ) - state_indices = torch.randperm(total_entries, dtype=torch.int32)[:batch] - - out_ref = [] - out_ref_b = [] - - return_final_states = final_states is not None - splits = torch.split(x, seqlens[1:].tolist(), dim=1) - for i, x_s in enumerate(splits): - out_ref_b.append( - causal_conv1d_ref( - x_s.unsqueeze(0), - weight, - bias, - activation=self.activation, - return_final_states=return_final_states, - final_states_out=( - final_states_ref[state_indices[i]].unsqueeze(0) - if return_final_states - else None - ), - initial_states=( - final_states_ref[state_indices[i]].unsqueeze(0) - if has_initial_states[i] - else None - ), - ) - ) - out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) - out_ref_tensor = torch.cat(out_ref, dim=0).squeeze(0) - - out = causal_conv1d_fwd( - x, - weight, - bias, - final_states, - query_start_loc, - state_indices, - has_initial_states, - self.activation in ["silu"], - PAD_SLOT_ID, - prepack, - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(out_ref_tensor, out, atol=atol, rtol=rtol) - torch.testing.assert_close(final_states_ref, final_states, atol=atol, rtol=rtol) - - @parametrize( - batch=[11], - dim=[32, 64, 96], - width=[4], - ) - def test_causal_conv1d_update( - self, batch, dim, width, has_bias=False, dtype=torch.bfloat16, prepack=True - ): - x = torch.randn(batch, dim).to(dtype) - conv_state = torch.randn(batch, dim, width - 1, dtype=dtype) - weight = torch.randn(dim, width).to(dtype) - bias = torch.randn(dim).to(dtype) if has_bias else None - - packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight - - conv_state_ref = conv_state.clone() - out_ref = causal_conv1d_update_ref( - x, conv_state_ref, weight, bias, activation=self.activation - ) - - cache_seqlens = None - conv_state_indices = None - out = causal_conv1d_update( - x, - conv_state, - packed_weight, - bias, - self.activation in ["silu"], - cache_seqlens, - conv_state_indices, - PAD_SLOT_ID, - prepack, - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) - torch.testing.assert_close(conv_state_ref, conv_state, atol=atol, rtol=rtol) - - @parametrize( - batch=[7], - dim=[96], - width=[4], - ) - def test_causal_conv1d_update_with_batch_gather( - self, batch, dim, width, has_bias=False, dtype=torch.bfloat16, prepack=True - ): - total_entries = batch + 3 - - x = torch.randn(batch, dim).to(dtype=dtype) - - conv_state_indices = torch.randperm(total_entries)[:batch].to(dtype=torch.int32) - conv_state = torch.randn(total_entries, dim, width - 1, dtype=dtype) - - weight = torch.randn(dim, width).to(dtype=dtype) - bias = torch.randn(dim).to(dtype=dtype) if has_bias else None - conv_state_ref = conv_state[conv_state_indices, :] - - packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight - - out_ref = causal_conv1d_update_ref( - x, conv_state_ref, weight, bias, activation=self.activation - ) - - cache_seqlens = None - out = causal_conv1d_update( - x, - conv_state, - packed_weight, - bias, - self.activation in ["silu"], - cache_seqlens, - conv_state_indices, - PAD_SLOT_ID, - prepack, - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) - torch.testing.assert_close( - conv_state_ref, conv_state[conv_state_indices, :], atol=atol, rtol=rtol - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_cpu_graph.py b/test/registered/cpu/test_cpu_graph.py deleted file mode 100644 index 37d70df72b24..000000000000 --- a/test/registered/cpu/test_cpu_graph.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -Usage: -python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu_torch_compile_cpu -""" - -import copy -import os -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - intel_amx_benchmark, - is_in_ci, - popen_launch_server, -) - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - - -class TestCPUGraph(CustomTestCase): - - @intel_amx_benchmark( - extra_args=[ - "--batch-size", - "1", - "--mem-fraction-static", - "0.05", - "--enable-torch-compile", - "--torch-compile-max-bs", - "2", - "--cuda-graph-bs", - "2", - ], - min_throughput=7, - ) - def test_latency_torch_compile_cpu(self): - return DEFAULT_MLA_MODEL_NAME_FOR_TEST - - def test_mmlu_torch_compile_cpu(self): - model = DEFAULT_MLA_MODEL_NAME_FOR_TEST - base_url = DEFAULT_URL_FOR_TEST - cpu_ids_by_node = get_cpu_ids_by_node() - n_numa_node = len(cpu_ids_by_node) - env = copy.deepcopy(os.environ) - env["SGLANG_CPU_OMP_THREADS_BIND"] = "all" - process = popen_launch_server( - model, - base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--attention-backend", - "intel_amx", - "--mem-fraction-static", - "0.05", - "--disable-radix", - "--trust-remote-code", - "--disable-overlap-schedule", - "--enable-torch-compile", - "--cuda-graph-bs", - "2", - "--tp", - f"{n_numa_node}", - ], - env=env, - ) - - try: - args = SimpleNamespace( - base_url=base_url, - model=model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - - metrics = run_eval(args) - if is_in_ci(): - self.assertGreater(metrics["score"], 0.45) - finally: - kill_process_tree(process.pid) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_decode.py b/test/registered/cpu/test_decode.py deleted file mode 100644 index 316c446bdab7..000000000000 --- a/test/registered/cpu/test_decode.py +++ /dev/null @@ -1,172 +0,0 @@ -import unittest - -import torch -from torch.nn.functional import scaled_dot_product_attention - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class TestDecodeAttention(CustomTestCase): - def _run_sdpa_forward_decode( - self, - query: torch.Tensor, - output: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - req_to_token: torch.Tensor, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - scaling=None, - enable_gqa=False, - causal=False, - ): - # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size] - query = query.movedim(0, query.dim() - 2) - - start_q, start_kv = 0, 0 - for seq_idx in range(seq_lens.shape[0]): - seq_len_q = 1 - seq_len_kv = seq_lens[seq_idx] - end_q = start_q + seq_len_q - end_kv = start_kv + seq_len_kv - - per_req_query = query[:, start_q:end_q, :] - - # get key and value from cache. per_req_tokens contains the kv cache - # index for each token in the sequence. - req_pool_idx = req_pool_indices[seq_idx] - per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] - per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) - per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) - - per_req_out = ( - scaled_dot_product_attention( - per_req_query.unsqueeze(0), - per_req_key.unsqueeze(0), - per_req_value.unsqueeze(0), - enable_gqa=enable_gqa, - scale=scaling, - is_causal=causal, - ) - .squeeze(0) - .movedim(query.dim() - 2, 0) - ) - output[start_q:end_q, :, :] = per_req_out - start_q, start_kv = end_q, end_kv - - return output - - def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, dtype, device): - # This represents the number of tokens already in the sequence - seq_len = 1024 - total_tokens = B * seq_len - sm_scale = 1.0 / (D**0.5) - logit_cap = 0.0 - num_kv_splits = 8 - enable_gqa = H_Q != H_KV - - # q represents the new token being generated, one per batch - q = torch.randn(B, H_Q, D, dtype=dtype, device=device) - - # k_buffer and v_buffer represent all previous tokens - k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device) - v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device) - - key = torch.randn(B, H_KV, D, dtype=dtype) - value = torch.randn(B, H_KV, D_V, dtype=dtype) - loc = torch.randint(0, 10, (B,)).to(torch.int64) - - # set kv cache - k_buffer[loc] = key - v_buffer[loc] = value - - # o will have the same shape as q - o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device) - o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device) - - req_to_token = ( - torch.arange(total_tokens, device=device) - .reshape(B, seq_len) - .to(torch.int32) - ) - b_req_idx = torch.arange(B, device=device).to(torch.int64) - b_seq_len = torch.full((B,), seq_len, device=device).to(torch.int64) - - attn_logits = torch.empty( - (B, H_Q, num_kv_splits, D_V + 1), - dtype=torch.float32, - device=device, - ) - - # k_buffer, v_buffer, query, key and value supports non-contiguous tensors - k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1) - v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1) - q = q.transpose(0, 1).contiguous().transpose(0, 1) - key = key.transpose(0, 1).contiguous().transpose(0, 1) - value = value.transpose(0, 1).contiguous().transpose(0, 1) - torch.ops.sgl_kernel.decode_attention_cpu( - q, - k_buffer, - v_buffer, - o, - key, - value, - loc, - attn_logits, - req_to_token, - b_req_idx, - b_seq_len, - sm_scale, - logit_cap, - ) - - self._run_sdpa_forward_decode( - q, - o_grouped, - k_buffer, - v_buffer, - req_to_token, - b_req_idx, - b_seq_len, - scaling=sm_scale, - enable_gqa=enable_gqa, - ) - - cos_sim = torch.nn.functional.cosine_similarity( - o.flatten(), o_grouped.flatten(), dim=0 - ) - self.assertGreater(cos_sim.item(), 0.99) - torch.testing.assert_close(o, o_grouped, atol=3e-2, rtol=1e-6) - - def _test_grouped_decode_attention(self, device="cpu"): - configs = [ - (2, 16, 16, 64, 64), - (2, 16, 1, 16, 16), - (2, 32, 8, 33, 55), - (2, 16, 1, 64, 64), - (2, 64, 1, 13, 13), - (2, 128, 1, 80, 80), - (2, 128, 2, 512, 512), - (1, 16, 1, 576, 512), - (1, 16, 16, 576, 512), - (1, 22, 1, 576, 512), - (1, 40, 8, 128, 128), - ] - - for B, H_Q, H_KV, D, D_V in configs: - for dtype in [torch.bfloat16, torch.float16]: - self._test_grouped_decode_attention_once( - B, H_Q, H_KV, D, D_V, dtype=dtype, device=device - ) - - def test_grouped_decode_attention(self): - self._test_grouped_decode_attention("cpu") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_extend.py b/test/registered/cpu/test_extend.py deleted file mode 100644 index 573531356761..000000000000 --- a/test/registered/cpu/test_extend.py +++ /dev/null @@ -1,225 +0,0 @@ -import unittest - -import torch -from torch.nn.functional import scaled_dot_product_attention - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class TestExtendAttention(CustomTestCase): - - def _run_sdpa_forward_extend( - self, - query: torch.Tensor, - output: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - req_to_token: torch.Tensor, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - extend_prefix_lens: torch.Tensor, - extend_seq_lens: torch.Tensor, - scaling=None, - enable_gqa=False, - causal=False, - ): - - assert seq_lens.shape[0] == extend_prefix_lens.shape[0] - assert seq_lens.shape[0] == extend_seq_lens.shape[0] - - # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size] - query = query.movedim(0, query.dim() - 2) - - start_q, start_kv = 0, 0 - for seq_idx in range(seq_lens.shape[0]): - - extend_seq_len_q = extend_seq_lens[seq_idx] - prefill_seq_len_q = extend_prefix_lens[seq_idx] - - seq_len_kv = seq_lens[seq_idx] - end_q = start_q + extend_seq_len_q - end_kv = start_kv + seq_len_kv - - per_req_query = query[:, start_q:end_q, :] - per_req_query_redudant = torch.empty( - (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]), - dtype=per_req_query.dtype, - device=per_req_query.device, - ) - - per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query - - # get key and value from cache. per_req_tokens contains the kv cache - # index for each token in the sequence. - req_pool_idx = req_pool_indices[seq_idx] - per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] - per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) - per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) - - per_req_out_redudant = ( - scaled_dot_product_attention( - per_req_query_redudant.unsqueeze(0), - per_req_key.unsqueeze(0), - per_req_value.unsqueeze(0), - enable_gqa=enable_gqa, - scale=scaling, - is_causal=causal, - ) - .squeeze(0) - .movedim(query.dim() - 2, 0) - ) - output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :] - start_q, start_kv = end_q, end_kv - return output - - def _test_extend_attention_once( - self, - B, - N_CTX, - H_Q, - H_KV, - D, - DV, - mla=False, - *, - b_seq_len_prefix=None, - b_seq_len_extend=None, - ): - dtype = torch.bfloat16 - - if b_seq_len_prefix is None: - b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32) - if mla: - b_seq_len_prefix.zero_() - else: - b_seq_len_prefix = torch.as_tensor(b_seq_len_prefix, dtype=torch.int32) - - if b_seq_len_extend is None: - b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32) - else: - b_seq_len_extend = torch.as_tensor(b_seq_len_extend, dtype=torch.int32) - - b_seq_len = b_seq_len_prefix + b_seq_len_extend - max_len_in_batch = torch.max(b_seq_len, 0)[0].item() - - b_req_idx = torch.arange(B, dtype=torch.int32) - req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32) - b_start_loc = torch.zeros((B,), dtype=torch.int32) - b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) - b_start_loc_extend = torch.zeros((B,), dtype=torch.int32) - b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - - for i in range(B): - req_to_tokens[i, : b_seq_len[i]] = torch.arange( - b_start_loc[i], b_start_loc[i] + b_seq_len[i] - ) - - total_token_num = torch.sum(b_seq_len).item() - extend_token_num = torch.sum(b_seq_len_extend).item() - - H_BUF = 1 if mla else H_KV - k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype) - v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype) - - k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype) - v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype) - q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype) - - for i in range(B): - extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] - extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] - extend_start = b_start_loc_extend[i] - extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] - k_extend[extend_start:extend_end] = k_buffer[ - extend_start_in_buffer:extend_end_in_buffer - ] - v_extend[extend_start:extend_end] = v_buffer[ - extend_start_in_buffer:extend_end_in_buffer - ] - q_extend[extend_start:extend_end] = ( - torch.randn((b_seq_len_extend[i], H_Q, D), dtype=dtype) * 20 - ) - - # q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors - q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1) - k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1) - v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1) - k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1) - v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1) - - b_seq_len_extend = b_seq_len - b_seq_len_prefix - b_start_loc_extend = torch.zeros_like(b_seq_len) - b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() - - sm_scale = 1.0 / (D**0.5) - logit_cap = 0.0 - - # handle index type - b_req_idx = b_req_idx.to(torch.int64) - b_seq_len = b_seq_len.to(torch.int64) - - enable_gqa = H_Q != H_KV - o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype) - self._run_sdpa_forward_extend( - q_extend, - o_ref, - k_buffer, - v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_seq_len_prefix, - b_seq_len_extend, - scaling=sm_scale, - enable_gqa=enable_gqa, - causal=True, - ) - - o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype) - torch.ops.sgl_kernel.extend_attention_cpu( - q_extend, - k_extend, - v_extend, - o_extend, - k_buffer, - v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_seq_len_extend, - b_start_loc_extend, - max_len_extend, - sm_scale, - logit_cap, - ) - - torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2) - - def test_extend_attention(self): - for is_mla in [True, False]: - self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla) - self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla) - self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla) - self._test_extend_attention_once(1, 9000, 16, 1, 32, 32, is_mla) - - def test_extend_attention_large_seq_causal_mask(self): - self._test_extend_attention_once( - B=1, - N_CTX=5001, - H_Q=8, - H_KV=2, - D=64, - DV=64, - b_seq_len_prefix=[0], - b_seq_len_extend=[5000], - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_flash_attn.py b/test/registered/cpu/test_flash_attn.py deleted file mode 100644 index f7f47c454750..000000000000 --- a/test/registered/cpu/test_flash_attn.py +++ /dev/null @@ -1,245 +0,0 @@ -import unittest - -import sgl_kernel # noqa: F401 -import torch -import torch.nn.functional as F -from utils import parametrize, precision - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -flash_attn_varlen_func = torch.ops.sgl_kernel.flash_attn_varlen_func - -torch.manual_seed(1234) - - -def flash_attn_varlen_ref( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - is_causal, - enable_gqa, -): - cu_q = cu_seqlens_q.tolist() - cu_k = cu_seqlens_k.tolist() - batch = len(cu_k) - 1 - - # [T, H, D] -> [1, H, T, D] - q, k, v = [x.unsqueeze(0).transpose(1, 2) for x in [q, k, v]] - - B, H, T, D = q.shape - out = torch.empty(B, H, T, v.size(-1), dtype=q.dtype) - for b in range(batch): - start_q, end_q = cu_q[b], cu_q[b + 1] - start_k, end_k = cu_k[b], cu_k[b + 1] - - out[:, :, start_q:end_q, :] = F.scaled_dot_product_attention( - q[:, :, start_q:end_q, :], - k[:, :, start_k:end_k, :], - v[:, :, start_k:end_k, :], - is_causal=is_causal, - enable_gqa=enable_gqa, - ) - - # [1, H, T, D] -> [T, H, D] - return out.transpose(1, 2).squeeze(0) - - -# faster version ref kernel for non varlen case -def flash_attn_non_varlen_ref( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - is_causal, - enable_gqa, -): - cu_q = cu_seqlens_q.tolist() - cu_k = cu_seqlens_k.tolist() - batch = len(cu_k) - 1 - - B_T, H, D = q.shape - T = B_T // batch - - # [T, H, D] -> [1, H, T, D] - q, k, v = [x.reshape(batch, T, H, D).transpose(1, 2) for x in [q, k, v]] - - out = F.scaled_dot_product_attention( - q, - k, - v, - is_causal=is_causal, - enable_gqa=enable_gqa, - ) - # [B, H, T, D] -> [B * T, H, D] - return out.transpose(1, 2).reshape(batch * T, H, D) - - -class TestFlashAttn(CustomTestCase): - - @parametrize( - batch=[4], - max_seqlen_q=[35, 96], - max_seqlen_k=[35, 96], - num_heads=[16], - num_heads_kv=[16, 2], - head_dim=[32, 48], # test when D is not 32x - head_dim_v=[32], - is_causal=[True, False], - ) - def test_flash_attn_varlen( - self, - batch, - max_seqlen_q, - max_seqlen_k, - num_heads, - num_heads_kv, - head_dim, - head_dim_v, - is_causal, - ): - dtype = torch.bfloat16 - - # random seqlens for k and kv - seqlens_q = torch.randint(1, max_seqlen_q, (batch,), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlen_k, (batch,), dtype=torch.int32) - cu_seqlens_q = torch.zeros((batch + 1,), dtype=torch.int32) - cu_seqlens_k = torch.zeros((batch + 1,), dtype=torch.int32) - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, 0) - cu_seqlens_k[1:] = torch.cumsum(seqlens_k, 0) - - sum_seqlen_q = seqlens_q.sum().item() - sum_seqlen_k = seqlens_k.sum().item() - q = torch.randn(sum_seqlen_q, num_heads, head_dim).to(dtype) - k = torch.randn(sum_seqlen_k, num_heads_kv, head_dim).to(dtype) - v = torch.randn(sum_seqlen_k, num_heads_kv, head_dim_v).to(dtype) - - out_ref = flash_attn_varlen_ref( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - is_causal=is_causal, - enable_gqa=num_heads != num_heads_kv, - ) - - out = flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqlens_q.max().item(), - seqlens_k.max().item(), - is_causal, - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) - - # test with large size to capture overflow issue - @parametrize( - batch=[4097], - max_seqlen_q=[4097], - max_seqlen_k=[4097], - num_heads=[4], - num_heads_kv=[4], - head_dim=[32], - head_dim_v=[32], - is_causal=[False], - ) - def test_flash_attn_large_size( - self, - batch, - max_seqlen_q, - max_seqlen_k, - num_heads, - num_heads_kv, - head_dim, - head_dim_v, - is_causal, - ): - dtype = torch.bfloat16 - - # test the non varlen case - seqlens_q = torch.full((batch,), max_seqlen_q, dtype=torch.int32) - seqlens_k = torch.full((batch,), max_seqlen_k, dtype=torch.int32) - - cu_seqlens_q = torch.zeros((batch + 1,), dtype=torch.int32) - cu_seqlens_k = torch.zeros((batch + 1,), dtype=torch.int32) - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, 0) - cu_seqlens_k[1:] = torch.cumsum(seqlens_k, 0) - - sum_seqlen_q = seqlens_q.sum().item() - sum_seqlen_k = seqlens_k.sum().item() - q = torch.randn(sum_seqlen_q, num_heads, head_dim).to(dtype) - k = torch.randn(sum_seqlen_k, num_heads_kv, head_dim).to(dtype) - v = torch.randn(sum_seqlen_k, num_heads_kv, head_dim_v).to(dtype) - - out_ref = flash_attn_non_varlen_ref( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - is_causal=is_causal, - enable_gqa=num_heads != num_heads_kv, - ) - - out = flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqlens_q.max().item(), - seqlens_k.max().item(), - is_causal, - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) - - def _test_flash_attn_large_seq_causal_mask_once(self, seqlens): - dtype = torch.bfloat16 - num_heads = 8 - num_heads_kv = 2 - head_dim = 64 - - seqlens_t = torch.tensor(seqlens, dtype=torch.int32) - cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32) - cu_seqlens[1:] = torch.cumsum(seqlens_t, 0) - total = cu_seqlens[-1].item() - max_seqlen = seqlens_t.max().item() - - q = torch.randn(total, num_heads, head_dim, dtype=dtype) - k = torch.randn(total, num_heads_kv, head_dim, dtype=dtype) - v = torch.randn(total, num_heads_kv, head_dim, dtype=dtype) - - out_ref = flash_attn_varlen_ref( - q, k, v, cu_seqlens, cu_seqlens, is_causal=True, enable_gqa=True - ) - out = flash_attn_varlen_func( - q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, True - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) - - def test_flash_attn_large_seq_causal_mask(self): - # Non-varlen path: single sequence, has_varlen_sequences returns False - # → dispatches to flash_attn_kernel_impl. - self._test_flash_attn_large_seq_causal_mask_once([5000]) - # Varlen path: sequences with different lengths, has_varlen_sequences - # returns True → dispatches to flash_attn_varlen_kernel_impl - self._test_flash_attn_large_seq_causal_mask_once([5000, 4999]) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_gemm.py b/test/registered/cpu/test_gemm.py deleted file mode 100644 index c5ea056e7b8a..000000000000 --- a/test/registered/cpu/test_gemm.py +++ /dev/null @@ -1,335 +0,0 @@ -import itertools -import unittest - -# TODO: use interface in cpu.py -import torch -import torch.nn as nn -from utils import ( - convert_weight, - native_w8a8_per_token_matmul, - per_token_quant_int8, - precision, - unpack_and_dequant_awq, - unpack_and_dequant_gptq, -) - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class Mod(nn.Module): - def __init__(self, input_channel, output_channel, has_bias): - super(Mod, self).__init__() - self.linear = torch.nn.Linear(input_channel, output_channel, has_bias) - - def forward(self, x): - return self.linear(x) - - -class TestGemm(CustomTestCase): - M = [1, 101] - N = [16, 32 * 13] - K = [32 * 16] - has_bias = [False, True] - - M_int8 = [2, 128] - N_int8 = [32 * 12] - K_int8 = [32 * 17] - - M_fp8 = [1, 11] - N_fp8 = [128, 224] - K_fp8 = [512, 576] - - M_awq = [1, 32] - N_awq = [4096] - K_awq = [4096] - - M_gptq = [1, 32] - N_gptq = [4096] - K_gptq = [4096] - - def _bf16_gemm(self, M, N, K, has_bias): - - mat1 = torch.randn(M, K, dtype=torch.bfloat16) - mat2 = torch.randn(N, K, dtype=torch.bfloat16) - - ref = torch.matmul(mat1.float(), mat2.float().t()) - if has_bias: - bias = torch.randn(N, dtype=torch.float32) - ref.add_(bias.bfloat16()) - - ref = ref.bfloat16() - - out = torch.ops.sgl_kernel.weight_packed_linear( - mat1, mat2, bias if has_bias else None, False - ) - - packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2) - out2 = torch.ops.sgl_kernel.weight_packed_linear( - mat1, packed_mat2, bias if has_bias else None, True - ) - - atol = rtol = precision[ref.dtype] - torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) - torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol) - - def test_bf16_gemm(self): - for params in itertools.product( - self.M, - self.N, - self.K, - self.has_bias, - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - has_bias=params[3], - ): - self._bf16_gemm(*params) - - def _bf16_gemm_with_small_oc(self, M, N, K, has_bias, use_post_sigmul): - use_post_sigmul = use_post_sigmul and N == 1 - mat_mul = ( - None if not use_post_sigmul else torch.randn(M, 2 * K, dtype=torch.bfloat16) - ) - mat1 = torch.randn(M, K, dtype=torch.bfloat16) - mat2 = torch.randn(N, K, dtype=torch.bfloat16) - - ref = torch.nn.functional.linear(mat1, mat2) - if has_bias: - bias = torch.randn(N, dtype=torch.float32) - ref.add_(bias) - if use_post_sigmul: - ref = torch.nn.functional.sigmoid(ref) * mat_mul - out = torch.ops.sgl_kernel.fused_linear_sigmoid_mul( - mat1, - torch.ops.sgl_kernel.convert_weight_packed(mat2), - bias if has_bias else None, - True, - mat_mul if use_post_sigmul else None, - ) - else: - out = torch.ops.sgl_kernel.weight_packed_linear( - mat1, - torch.ops.sgl_kernel.convert_weight_packed(mat2), - bias if has_bias else None, - True, - ) - atol = rtol = precision[ref.dtype] - torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) - - def test_bf16_gemm_with_small_oc(self): - for params in itertools.product( - [1, 8, 32, 1024], [12, 1], self.K, self.has_bias, [False, True] - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - has_bias=params[3], - use_post_sigmul=params[4], - ): - self._bf16_gemm_with_small_oc(*params) - - def _int8_gemm(self, M, N, K, has_bias): - dtype = torch.bfloat16 - A = torch.randn((M, K), dtype=dtype) / 10 - Aq, As = per_token_quant_int8(A) - - factor_for_scale = 1e-2 - int8_max = 127 - int8_min = -128 - - B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2 - Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) - Bs = torch.rand(N) * factor_for_scale - - bias = torch.randn(N) if has_bias else None - ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype) - - atol = rtol = precision[ref_out.dtype] - - Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A) - out = torch.ops.sgl_kernel.int8_scaled_mm_cpu( - Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False - ) - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - # test the fused version - fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant( - A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False - ) - torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol) - - def test_int8_gemm(self): - for params in itertools.product( - self.M_int8, - self.N_int8, - self.K_int8, - self.has_bias, - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - has_bias=params[3], - ): - self._int8_gemm(*params) - - def _fp8_gemm(self, M, N, K, has_bias): - prepack = True - chunk = False - scale_block_size_N = 64 - scale_block_size_K = 128 - assert scale_block_size_N <= N - assert scale_block_size_K <= K - A_dtype = torch.bfloat16 - - model = Mod(K, N, has_bias).eval() - if chunk: - data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K) - else: - data = torch.randn(M, K, dtype=A_dtype) - - weight = model.linear.weight # (N, K) - - if has_bias: - bias = model.linear.bias - - fp8_weight, scales, dq_weight = convert_weight( - weight, [scale_block_size_N, scale_block_size_K], A_dtype - ) - - if has_bias: - ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype) - else: - ref = torch.matmul(data.to(A_dtype), dq_weight.T) - - if prepack: - fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight) - - opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu( - data, - fp8_weight, - scales, - [scale_block_size_N, scale_block_size_K], - bias if has_bias else None, - data.dtype, - prepack, - ) - atol = rtol = precision[ref.dtype] - torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol) - - def test_fp8_gemm(self): - for params in itertools.product( - self.M_fp8, - self.N_fp8, - self.K_fp8, - self.has_bias, - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - has_bias=params[3], - ): - self._fp8_gemm(*params) - - def _int4_awq_gemm(self, M, N, K, group_size, has_bias): - awq_weight = torch.randint(-128, 128, (K, N // 8)).to(torch.int) - awq_zero = torch.randint(0, 10, (K // group_size, N // 8)).to(torch.int) - awq_scales = torch.rand(int(K // group_size), N).to(torch.bfloat16) - bf16_weight, _ = unpack_and_dequant_awq( - awq_weight, awq_zero, awq_scales, 4, 128 - ) - if has_bias: - bias = torch.rand(bf16_weight.shape[0]).to(torch.float) - else: - bias = None - x = torch.rand(M, bf16_weight.size(-1)).to(torch.bfloat16) - ref_res = torch.nn.functional.linear( - x, bf16_weight, bias=bias.to(torch.bfloat16) if has_bias else None - ) - - packed_weight, packed_zero, packed_scales = ( - torch.ops.sgl_kernel.convert_weight_packed_scale_zp( - awq_weight, awq_zero, awq_scales, 0 - ) - ) - target_res = torch.ops.sgl_kernel.int4_scaled_mm_cpu( - x, - packed_weight, - packed_zero, - packed_scales, - bias, - ) - - atol = rtol = precision[ref_res.dtype] - torch.testing.assert_close(ref_res, target_res, atol=atol, rtol=rtol) - - def test_int4_awq_gemm(self): - for params in itertools.product( - self.M_awq, self.N_awq, self.K_awq, [128], self.has_bias - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - group_size=params[3], - has_bias=params[4], - ): - self._int4_awq_gemm(*params) - - def _int4_gptq_gemm(self, M, N, K, group_size, has_bias): - torch.manual_seed(127) - gptq_weight = torch.randint(-128, 128, (K // 8, N)).to(torch.int) - gptq_zero = torch.randint(0, 10, (K // group_size, N // 8)).to(torch.int) - gptq_scales = torch.rand(int(K // group_size), N).to(torch.bfloat16) // 10 - - bf16_weight = unpack_and_dequant_gptq(gptq_weight, gptq_zero, gptq_scales) - if has_bias: - bias = torch.rand(bf16_weight.shape[0]).to(torch.float) - else: - bias = None - x = torch.rand(M, bf16_weight.size(-1)).to(torch.bfloat16) - ref_res = torch.nn.functional.linear( - x, bf16_weight, bias=bias.to(torch.bfloat16) if has_bias else None - ) - - packed_weight, packed_zero, packed_scales = ( - torch.ops.sgl_kernel.convert_weight_packed_scale_zp( - gptq_weight, gptq_zero, gptq_scales, 1 - ) - ) - target_res = torch.ops.sgl_kernel.int4_scaled_mm_cpu( - x, - packed_weight, - packed_zero, - packed_scales, - bias, - ) - - atol = rtol = precision[ref_res.dtype] - torch.testing.assert_close(ref_res, target_res, atol=atol, rtol=rtol) - - def test_int4_gptq_gemm(self): - for params in itertools.product( - self.M_gptq, self.N_gptq, self.K_gptq, [128], self.has_bias - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - group_size=params[3], - has_bias=params[4], - ): - self._int4_gptq_gemm(*params) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_intel_amx_attention_backend_a.py b/test/registered/cpu/test_intel_amx_attention_backend_a.py deleted file mode 100644 index ba60527d2e25..000000000000 --- a/test/registered/cpu/test_intel_amx_attention_backend_a.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Usage: -python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_latency_default_model -""" - -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import kill_process_tree -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, - DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - intel_amx_benchmark, - is_in_ci, - popen_launch_server, -) - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - - -class TestIntelAMXAttnBackend(CustomTestCase): - - @intel_amx_benchmark( - extra_args=["--batch-size", "4", "--mem-fraction-static", "0.3"], - min_throughput=10, - ) - def test_latency_mla_model(self): - return DEFAULT_MLA_MODEL_NAME_FOR_TEST - - @intel_amx_benchmark( - extra_args=["--batch-size", "4", "--mem-fraction-static", "0.1"], - min_throughput=40, - ) - def test_latency_default_model(self): - return DEFAULT_MODEL_NAME_FOR_TEST - - def test_mmlu(self): - model = DEFAULT_MLA_MODEL_NAME_FOR_TEST - base_url = DEFAULT_URL_FOR_TEST - process = popen_launch_server( - model, - base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--attention-backend", - "intel_amx", - "--mem-fraction-static", - "0.3", - "--disable-radix", - "--trust-remote-code", - "--disable-overlap-schedule", - ], - ) - - try: - args = SimpleNamespace( - base_url=base_url, - model=model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - metrics = run_eval(args) - if is_in_ci(): - self.assertGreater(metrics["score"], 0.45) - finally: - kill_process_tree(process.pid) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_intel_amx_attention_backend_b.py b/test/registered/cpu/test_intel_amx_attention_backend_b.py deleted file mode 100644 index 104f417c5bd2..000000000000 --- a/test/registered/cpu/test_intel_amx_attention_backend_b.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -For intel_amx attention backend FP8 tests -Usage: -python3 -m unittest test_intel_amx_attention_backend_1.TestIntelAMXAttnBackendQuant.test_latency_fp8_qwen -""" - -import unittest - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, - DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8, - CustomTestCase, - intel_amx_benchmark, -) - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - - -class TestIntelAMXAttnBackendQuant(CustomTestCase): - - @intel_amx_benchmark( - extra_args=["--batch-size", "4", "--mem-fraction-static", "0.1"], - min_throughput=150, - ) - def test_latency_fp8_qwen(self): - return DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8 - - @intel_amx_benchmark( - extra_args=["--batch-size", "4", "--mem-fraction-static", "0.1"], - min_throughput=50, - ) - def test_latency_fp8_moe_model(self): - return DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_intel_amx_attention_backend_c.py b/test/registered/cpu/test_intel_amx_attention_backend_c.py deleted file mode 100644 index 391d21a2b095..000000000000 --- a/test/registered/cpu/test_intel_amx_attention_backend_c.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -For intel_amx attention backend w8a8 tests -Usage: -python3 -m unittest test_intel_amx_attention_backend_2.TestIntelAMXAttnBackendQuant.test_latency_w8a8_default_model -""" - -import unittest - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST_W8A8, - DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, - CustomTestCase, - intel_amx_benchmark, -) - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - - -class TestIntelAMXAttnBackendQuant(CustomTestCase): - - @intel_amx_benchmark( - extra_args=[ - "--batch-size", - "4", - "--quantization", - "w8a8_int8", - "--mem-fraction-static", - "0.1", - ], - min_throughput=100, - ) - def test_latency_w8a8_default_model(self): - return DEFAULT_MODEL_NAME_FOR_TEST_W8A8 - - @intel_amx_benchmark( - extra_args=[ - "--batch-size", - "4", - "--quantization", - "w8a8_int8", - "--mem-fraction-static", - "0.9", - "--max-total-tokens", - "65536", - "--tp", - "6", - ], - min_throughput=100, - ) - def test_latency_w8a8_moe_model(self): - return DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_mamba.py b/test/registered/cpu/test_mamba.py deleted file mode 100644 index 25275e908e77..000000000000 --- a/test/registered/cpu/test_mamba.py +++ /dev/null @@ -1,397 +0,0 @@ -import unittest - -import torch -import torch.nn.functional as F -from torch.nn.functional import softplus -from utils import precision - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6): - """This function is intended to align with the l2norm implementation in the FLA library.""" - inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) - return x * inv_norm - - -def torch_chunk_gated_delta_rule( - query, - key, - value, - g, - beta, - chunk_size=64, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=False, -): - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = l2norm(query, dim=-1, eps=1e-6) - key = l2norm(key, dim=-1, eps=1e-6) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - batch_size, sequence_length, num_heads, k_head_dim = key.shape - v_head_dim = value.shape[-1] - pad_size = (chunk_size - num_heads % chunk_size) % chunk_size - query = F.pad(query, (0, 0, 0, pad_size)) - key = F.pad(key, (0, 0, 0, pad_size)) - value = F.pad(value, (0, 0, 0, pad_size)) - beta = F.pad(beta, (0, pad_size)) - g = F.pad(g, (0, pad_size)) - tot_heads = num_heads + pad_size - scale = 1 / (query.shape[-1] ** 0.5) - query = query * scale - - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - # reshape to chunks - query, key, value, k_beta, v_beta = [ - x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) - for x in (query, key, value, k_beta, v_beta) - ] - g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) - mask = torch.triu( - torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), - diagonal=0, - ) - - # chunk decay - g = g.cumsum(dim=-1) - decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() - attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - last_recurrent_state = ( - torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value) - if initial_state is None - else initial_state.to(value) - ) - core_attn_out = torch.zeros_like(value) - mask = torch.triu( - torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), - diagonal=1, - ) - - # for each chunk - for i in range(0, tot_heads // chunk_size): - q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn @ v_new - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() - + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( - -1, -2 - ) - @ v_new - ) - - if not output_final_state: - last_recurrent_state = None - core_attn_out = core_attn_out.reshape( - core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1] - ) - core_attn_out = core_attn_out[:, :, :num_heads] - core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state - - -def chunk_gated_delta_rule_update( - query, # [B, T, HK, K] - key, # [B, T, HK, K] - value, # [B, T, HV, V] - g, # [B, T, HV] - beta, # [B, T, HV] - cu_seqlens, # [N+1] - initial_state, # [N, HV, K, V] - use_qk_l2norm_in_kernel, # True -): - num_heads = query.shape[2] - num_value_heads = value.shape[2] - batch_size = initial_state.shape[0] - if num_value_heads // num_heads > 1: - query = query.repeat_interleave(num_value_heads // num_heads, dim=2) - key = key.repeat_interleave(num_value_heads // num_heads, dim=2) - output = torch.empty_like(value) - final_state = torch.empty_like(initial_state) - start_q = 0 - for i in range(batch_size): - end_q = cu_seqlens[i + 1] - core_attn_outi, last_recurrent_state = torch_chunk_gated_delta_rule( - query=query[:, start_q:end_q, :, :], - key=key[:, start_q:end_q, :, :], - value=value[:, start_q:end_q, :, :], - g=g[:, start_q:end_q, :], - beta=beta[:, start_q:end_q, :], - initial_state=initial_state[i], - output_final_state=True, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) - output[:, start_q:end_q, :, :] = core_attn_outi - final_state[i] = last_recurrent_state - start_q = end_q - return output, final_state - - -def torch_recurrent_gated_delta_rule( - query, - key, - value, - g, - beta, - initial_state, - output_final_state, - use_qk_l2norm_in_kernel=False, -): - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = l2norm(query, dim=-1, eps=1e-6) - key = l2norm(key, dim=-1, eps=1e-6) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - scale = 1 / (query.shape[-1] ** 0.5) - query = query * scale - - core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to( - value - ) - last_recurrent_state = ( - torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) - if initial_state is None - else initial_state.to(value) - ) - - for i in range(sequence_length): - q_t = query[:, :, i] - k_t = key[:, :, i] - v_t = value[:, :, i] - g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) - beta_t = beta[:, :, i].unsqueeze(-1) - - last_recurrent_state = last_recurrent_state * g_t - kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - delta = (v_t - kv_mem) * beta_t - last_recurrent_state = last_recurrent_state + k_t.unsqueeze( - -1 - ) * delta.unsqueeze(-2) - core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - - if not output_final_state: - last_recurrent_state = None - core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state - - -def sigmoid_gating_delta_rule_update( - query, - key, - value, - A_log, - a, - dt_bias, - b, - initial_state, - output_final_state, - use_qk_l2norm_in_kernel=False, -): - beta = b.sigmoid() - g = -A_log.float().exp() * softplus(a.float() + dt_bias) - return torch_recurrent_gated_delta_rule( - query, - key, - value, - g.unsqueeze(0), - beta.unsqueeze(0), - initial_state, - output_final_state, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) - - -def torch_gdn_gating(A_log, a, b, dt_bias): - return -A_log.float().exp() * softplus(a.float() + dt_bias).unsqueeze( - 0 - ), b.sigmoid().unsqueeze(0) - - -class TestMambaAttention(CustomTestCase): - def test_chunk_gated_delta_rule(self): - B, L, HK, HV, EK, EV, N = 1, 100, 3, 6, 64, 64, 4 - seqlens = torch.randint(1, L, (N + 1,)) - seqlens[0] = 0 - cu_seqlens_ = torch.cumsum(seqlens, dim=0).to(torch.int32) - T = cu_seqlens_[-1].item() - query_ = torch.rand((B, T, HK, EK), dtype=torch.bfloat16) * 0.05 - key_ = torch.rand((B, T, HK, EK), dtype=torch.bfloat16) * 0.05 - value_ = torch.rand((B, T, HV, EV), dtype=torch.bfloat16) * 0.05 - g_ = torch.rand((B, T, HV), dtype=torch.float32) * 0.05 - beta_ = torch.rand((B, T, HV), dtype=torch.bfloat16) * 0.05 - initial_state_ = torch.rand((N, HV, EK, EV), dtype=torch.float32) * 0.05 - - for use_qk_l2norm_in_kernel in [True, False]: - core_attn_out_ref, last_recurrent_state_ref = chunk_gated_delta_rule_update( - query=query_, - key=key_, - value=value_, - g=g_, - beta=beta_, - cu_seqlens=cu_seqlens_, - initial_state=initial_state_, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) - - query = query_.clone() - key = key_.clone() - value = value_.clone() - g = g_.clone() - beta = beta_.clone() - cu_seqlens = cu_seqlens_.clone() - initial_state = initial_state_.clone() - - core_attn_out, last_recurrent_state = ( - torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu( - query=query, - key=key, - value=value, - g=g, - beta=beta, - initial_state=initial_state, - output_final_state=True, - cu_seqlens=cu_seqlens, - head_first=False, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) - ) - atol = rtol = precision[core_attn_out.dtype] - torch.testing.assert_close( - core_attn_out, core_attn_out_ref, atol=atol, rtol=rtol - ) - torch.testing.assert_close( - last_recurrent_state, last_recurrent_state_ref, atol=atol, rtol=rtol - ) - - def test_fused_gdn_gating(self): - dims = [6, 32] - for dim in dims: - for A_log_dtype in [torch.float32, torch.bfloat16]: - A_log = torch.rand(dim, dtype=A_log_dtype) - a = torch.rand(1024, dim, dtype=torch.bfloat16) - b = torch.rand(1024, dim, dtype=torch.bfloat16) - dt_bias = torch.rand(dim, dtype=torch.bfloat16) - - g, beta = torch_gdn_gating(A_log, a, b, dt_bias) - g_sgl, beta_sgl = torch.ops.sgl_kernel.fused_gdn_gating_cpu( - A_log, a, b, dt_bias - ) - atol = rtol = precision[g.dtype] - atol2 = rtol2 = precision[beta.dtype] - torch.testing.assert_close(g, g_sgl, atol=atol, rtol=rtol) - torch.testing.assert_close(beta, beta_sgl, atol=atol2, rtol=rtol2) - - def test_fused_sigmoid_gating_delta_rule_update(self): - batch_size = 1 - num_value_heads = 32 - head_k_dim = 128 - head_v_dim = 128 - num_heads = 16 - seq_len = 1 - attn_tp_size = 1 - key_dim = head_k_dim * num_heads - value_dim = head_v_dim * num_value_heads - mixed_qkv_dim = (key_dim * 2 + value_dim) // attn_tp_size - mixed_qkv = torch.rand( - seq_len * batch_size, mixed_qkv_dim, dtype=torch.bfloat16 - ) - query, key, value = torch.split( - mixed_qkv, - [ - key_dim // attn_tp_size, - key_dim // attn_tp_size, - value_dim // attn_tp_size, - ], - dim=-1, - ) - query = query.view(1, seq_len, num_heads, head_k_dim) - key = key.view(1, seq_len, num_heads, head_k_dim) - value = value.view(1, seq_len, num_value_heads, head_v_dim) - A_log = torch.rand(num_value_heads, dtype=torch.float32) - a = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16) - b = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16) - dt_bias = torch.rand(num_value_heads, dtype=torch.bfloat16) - ssm_states = torch.rand( - 513, num_value_heads, head_k_dim, head_v_dim, dtype=torch.float32 - ) - cache_indices = torch.randint(0, 513, (batch_size,), dtype=torch.int32) - query_start_loc = torch.tensor([0, 1], dtype=torch.int32) - use_qk_l2norm_in_kernel = True - query_ref = query.clone() - key_ref = key.clone() - if num_value_heads // num_heads > 1: - query_ref = query_ref.repeat_interleave(num_value_heads // num_heads, dim=2) - key_ref = key_ref.repeat_interleave(num_value_heads // num_heads, dim=2) - for A_log_dtype in [torch.float32, torch.bfloat16]: - A_log = A_log.to(A_log_dtype) - core_attn_out_ref, last_recurrent_state_ref = ( - sigmoid_gating_delta_rule_update( - query_ref.transpose(0, 1), - key_ref.transpose(0, 1), - value.transpose(0, 1), - A_log, - a, - dt_bias, - b, - initial_state=ssm_states[cache_indices], - output_final_state=True, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) - ) - core_attn_out = ( - torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu( - A_log=A_log, - dt_bias=dt_bias, - q=query, - k=key, - v=value, - a=a, - b=b, - initial_state_source=ssm_states, - initial_state_indices=cache_indices, - cu_seqlens=query_start_loc, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - softplus_beta=1.0, - softplus_threshold=20.0, - ) - ) - last_recurrent_state = ssm_states[cache_indices] - atol = rtol = precision[core_attn_out.dtype] - torch.testing.assert_close( - core_attn_out, core_attn_out_ref, atol=atol, rtol=rtol - ) - torch.testing.assert_close( - last_recurrent_state, last_recurrent_state_ref, atol=atol, rtol=rtol - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_mla.py b/test/registered/cpu/test_mla.py deleted file mode 100644 index f6714ced5ac5..000000000000 --- a/test/registered/cpu/test_mla.py +++ /dev/null @@ -1,158 +0,0 @@ -import unittest - -import torch -from torch.nn.functional import scaled_dot_product_attention -from utils import precision - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class TestMLA(CustomTestCase): - def _run_sdpa_forward_decode( - self, - query: torch.Tensor, - output: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - key: torch.Tensor, - loc: torch.Tensor, - req_to_token: torch.Tensor, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - scaling=None, - enable_gqa=False, - causal=False, - ): - # set kv cache - k_cache[loc] = key - - # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size] - query = query.movedim(0, query.dim() - 2) - - start_q, start_kv = 0, 0 - for seq_idx in range(seq_lens.shape[0]): - seq_len_q = 1 - seq_len_kv = seq_lens[seq_idx] - end_q = start_q + seq_len_q - end_kv = start_kv + seq_len_kv - - per_req_query = query[:, start_q:end_q, :] - - # get key and value from cache. per_req_tokens contains the kv cache - # index for each token in the sequence. - req_pool_idx = req_pool_indices[seq_idx] - per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] - per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) - per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) - - per_req_out = ( - scaled_dot_product_attention( - per_req_query.unsqueeze(0), - per_req_key.unsqueeze(0), - per_req_value.unsqueeze(0), - enable_gqa=enable_gqa, - scale=scaling, - is_causal=causal, - ) - .squeeze(0) - .movedim(query.dim() - 2, 0) - ) - output[start_q:end_q, :, :] = per_req_out - start_q, start_kv = end_q, end_kv - - return output - - def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len): - dtype = torch.bfloat16 - - total_tokens = B * seq_len - sm_scale = 1.0 / (D**0.5) - logit_cap = 0.0 - num_kv_splits = 8 - enable_gqa = H_Q != H_KV - - # q represents the new token being generated, one per batch - q = torch.randn(B, H_Q, D, dtype=dtype) - - # k_buffer and v_buffer represent all previous tokens - k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype) - v_buffer = k_buffer.narrow(2, 0, D_V) - - key = torch.randn(B, H_KV, D, dtype=dtype) - value = key.narrow(2, 0, D_V) - # make sure no duplicates in loc - loc = torch.randperm(total_tokens)[:B].to(torch.int64) - - k_buffer2 = k_buffer.clone() - v_buffer2 = k_buffer2.narrow(2, 0, D_V) - - # o will have the same shape as q - o = torch.zeros(B, H_Q, D_V, dtype=dtype) - o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype) - - req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32) - b_req_idx = torch.arange(B).to(torch.int64) - b_seq_len = torch.full((B,), seq_len).to(torch.int64) - - attn_logits = torch.empty( - (B, H_Q, num_kv_splits, D_V + 1), - dtype=torch.float32, - ) - - torch.ops.sgl_kernel.decode_attention_cpu( - q, - k_buffer2, - v_buffer2, - o, - key, - value, - loc, - attn_logits, - req_to_token, - b_req_idx, - b_seq_len, - sm_scale, - logit_cap, - ) - - self._run_sdpa_forward_decode( - q, - o_grouped, - k_buffer, - v_buffer, - key, - loc, - req_to_token, - b_req_idx, - b_seq_len, - scaling=sm_scale, - enable_gqa=enable_gqa, - ) - - cos_sim = torch.nn.functional.cosine_similarity( - o.flatten(), o_grouped.flatten(), dim=0 - ) - atol = rtol = precision[q.dtype] - self.assertGreater(cos_sim.item(), 0.99) - torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol) - torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol) - torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol) - - def test_grouped_decode_attention(self): - configs = [ - (1, 22, 1, 576, 512, 8 * 111), - (4, 22, 1, 576, 512, 8 * 128), - (40, 22, 1, 576, 512, 8 * 133), - ] - - for B, H_Q, H_KV, D, D_V, seqlen in configs: - self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_moe.py b/test/registered/cpu/test_moe.py deleted file mode 100644 index f3c817bec331..000000000000 --- a/test/registered/cpu/test_moe.py +++ /dev/null @@ -1,355 +0,0 @@ -import itertools -import math -import unittest - -# TODO: use interface in cpu.py -import torch - -from sglang.srt.layers.amx_utils import CPUQuantMethod - -kernel = torch.ops.sgl_kernel - -torch.manual_seed(128) - -from utils import ( - BLOCK_K, - BLOCK_N, - factor_for_scale, - fp8_max, - fp8_min, - native_fp8_fused_moe, - precision, - scaled_weight, - torch_naive_fused_moe, - torch_w8a8_per_column_fused_moe, - unpack_and_dequant_awq, -) - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - - -def fused_moe(a, w1, w2, score, topk, renormalize, prepack): - - G = 1 - topk_group = 1 - - B, D = a.shape - topk_weights = torch.empty(B, topk, dtype=torch.float32) - topk_ids = torch.empty(B, topk, dtype=torch.int32) - topk_weights, topk_ids = kernel.grouped_topk_cpu( - a, score, topk, renormalize, G, topk_group, 0, None, None - ) - - packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1 - packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2 - - inplace = True - return kernel.fused_experts_cpu( - a, - packed_w1, - packed_w2, - topk_weights, - topk_ids, - inplace, - CPUQuantMethod.UNQUANT, - None, - None, - None, - None, - None, - prepack, - ) - - -class TestFusedExperts(CustomTestCase): - M = [2, 114] - N = [32] - K = [32] - E = [4] - topk = [2] - renormalize = [False, True] - - M_int8 = [1, 39] - N_int8 = [128] - K_int8 = [256] - E_int8 = [8] - topk_int8 = [3] - - M_fp8 = [2, 121] - N_fp8 = [352, 512] - K_fp8 = [256, 320] - E_fp8 = [8] - topk_fp8 = [4] - - M_int4 = [1, 6] - N_int4 = [512] - K_int4 = [256] - E_int4 = [8] - topk_int4 = [4] - - def _bf16_moe(self, m, n, k, e, topk, renormalize): - dtype = torch.bfloat16 - prepack = True - - a = torch.randn((m, k), device="cpu", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10 - score = torch.randn((m, e), device="cpu", dtype=dtype) - - torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize) - fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack) - - atol = rtol = precision[torch_output.dtype] - torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol) - - def test_bf16_moe(self): - for params in itertools.product( - self.M, - self.N, - self.K, - self.E, - self.topk, - self.renormalize, - ): - with self.subTest( - m=params[0], - n=params[1], - k=params[2], - e=params[3], - topk=params[4], - renormalize=params[5], - ): - self._bf16_moe(*params) - - def _int8_moe(self, M, N, K, E, topk): - dtype = torch.bfloat16 - prepack = True - - # Initialize int8 quantization parameters - int8_factor_for_scale = 1e-2 - int8_max = 127 - int8_min = -128 - - # Input tensor - # M * K - a = torch.randn((M, K), dtype=dtype) / math.sqrt(K) - - # Generate int8 weights - w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 - w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) - - w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 - w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) - - # Generate scale for each column (per-column quantization) - w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale - w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale - - # Calculate routing - score = torch.randn((M, E), dtype=dtype) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - - ref_out = torch_w8a8_per_column_fused_moe( - a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk - ) - - inplace = True - packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1 - packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2 - out = kernel.fused_experts_cpu( - a, - packed_w1, - packed_w2, - topk_weight, - topk_ids.to(torch.int32), - inplace, - CPUQuantMethod.INT8_W8A8, - w1_s, - w2_s, - None, - None, - None, - prepack, - ) - - atol = rtol = precision[ref_out.dtype] - # Increase the tolerance for large input shapes - if M > 35: - atol = rtol = 0.02 - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - def test_int8_moe(self): - for params in itertools.product( - self.M_int8, - self.N_int8, - self.K_int8, - self.E_int8, - self.topk_int8, - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - E=params[3], - topk=params[4], - ): - self._int8_moe(*params) - - def _fp8_moe(self, M, N, K, E, topk): - dtype = torch.bfloat16 - - a = torch.randn(M, K, dtype=dtype) / math.sqrt(K) - - w1_fp32 = torch.randn(E, 2 * N, K) - w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - w2_fp32 = torch.randn(E, K, N) - w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - w1s = ( - torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K)) - * factor_for_scale - ) - w2s = ( - torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K)) - * factor_for_scale - ) - - w1_scaled = scaled_weight(w1, w1s) - w2_scaled = scaled_weight(w2, w2s) - - score = torch.randn((M, E), dtype=dtype) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - - w1 = kernel.convert_weight_packed(w1) - w2 = kernel.convert_weight_packed(w2) - - ref_out = native_fp8_fused_moe( - a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk - ) - out = kernel.fused_experts_cpu( - a, - w1, - w2, - topk_weight, - topk_ids.to(torch.int32), - False, - CPUQuantMethod.FP8_W8A16, - w1s, - w2s, - None, - None, - [BLOCK_N, BLOCK_K], - True, - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol) - - def test_fp8_moe(self): - for params in itertools.product( - self.M_fp8, - self.N_fp8, - self.K_fp8, - self.E_fp8, - self.topk_fp8, - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - E=params[3], - topk=params[4], - ): - self._fp8_moe(*params) - - def _int4_moe(self, M, N, K, E, topk, group_size=128): - dtype = torch.bfloat16 - - a = torch.rand(M, K, dtype=dtype) / math.sqrt(K) - - awq_w13_weight = torch.randint(-127, 128, (E, K, 2 * N // 8)).to(torch.int) - awq_w13_zero = torch.randint(0, 10, (E, K // group_size, 2 * N // 8)).to( - torch.int - ) - awq_w13_scales = torch.rand(E, int(K // group_size), 2 * N).to(torch.bfloat16) - - awq_w2_weight = torch.randint(-127, 128, (E, N, K // 8)).to(torch.int) - awq_w2_zero = torch.randint(0, 10, (E, N // group_size, K // 8)).to(torch.int) - awq_w2_scales = torch.rand(E, int(N // group_size), K).to(torch.bfloat16) - bf16_w13_weight = [] - bf16_w2_weight = [] - for i in range(E): - bf16_w13_weight_i, _ = unpack_and_dequant_awq( - awq_w13_weight[i], awq_w13_zero[i], awq_w13_scales[i], 4, 128 - ) - bf16_w2_weight_i, _ = unpack_and_dequant_awq( - awq_w2_weight[i], awq_w2_zero[i], awq_w2_scales[i], 4, 128 - ) - bf16_w13_weight.append(bf16_w13_weight_i) - bf16_w2_weight.append(bf16_w2_weight_i) - bf16_w13_weight = torch.stack(bf16_w13_weight).detach() - bf16_w2_weight = torch.stack(bf16_w2_weight).detach() - - score = torch.rand((M, E), dtype=dtype) - - ref_out = torch_naive_fused_moe( - a, bf16_w13_weight, bf16_w2_weight, score, topk, False - ) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - awq_w13_weight_pack, awq_w13_zero_pack, awq_w13_scales_pack = ( - torch.ops.sgl_kernel.convert_weight_packed_scale_zp( - awq_w13_weight, awq_w13_zero, awq_w13_scales, 0 - ) - ) - awq_w2_weight_pack, awq_w2_zero_pack, awq_w2_scales_pack = ( - torch.ops.sgl_kernel.convert_weight_packed_scale_zp( - awq_w2_weight, awq_w2_zero, awq_w2_scales, 0 - ) - ) - - out = kernel.fused_experts_cpu( - a, - awq_w13_weight_pack, - awq_w2_weight_pack, - topk_weight, - topk_ids.to(torch.int32), - False, - CPUQuantMethod.INT4_W4A8, - awq_w13_scales_pack, - awq_w2_scales_pack, - awq_w13_zero_pack, - awq_w2_zero_pack, - None, - True, - ) - - atol = rtol = precision[dtype] - torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol) - - def test_int4_moe(self): - for params in itertools.product( - self.M_int4, - self.N_int4, - self.K_int4, - self.E_int4, - self.topk_int4, - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - E=params[3], - topk=params[4], - ): - self._int4_moe(*params) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_norm.py b/test/registered/cpu/test_norm.py deleted file mode 100644 index 2da4d5117a99..000000000000 --- a/test/registered/cpu/test_norm.py +++ /dev/null @@ -1,435 +0,0 @@ -import itertools -import unittest -from typing import Optional, Tuple, Union - -import torch -from utils import make_non_contiguous, parametrize, precision - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class TestNorm(CustomTestCase): - - def _forward_native( - self, - x: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float = 1e-6, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - orig_dtype = x.dtype - x = x.to(torch.float32) - if residual is not None: - x = x + residual.to(torch.float32) - residual = x.to(orig_dtype) - - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + variance_epsilon) - x = x.to(orig_dtype) * weight - if residual is None: - return x - else: - return x, residual - - def _norm(self, x, eps): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - def _gemma3_rmsnorm_native( - self, x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6 - ): - output = self._norm(x.float(), variance_epsilon) - output = output * (1.0 + weight.float()) - return output.type_as(x) - - def _gemma_rmsnorm_native( - self, - x: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float = 1e-6, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - orig_dtype = x.dtype - if residual is not None: - x = x + residual - residual = x - - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + variance_epsilon) - x = x * (1.0 + weight.float()) - x = x.to(orig_dtype) - return x if residual is None else (x, residual) - - @parametrize( - m=[4096, 1024], - n=[4096, 4109], - dtype=[torch.float16, torch.bfloat16], - ) - def test_norm(self, m, n, dtype): - - x = torch.randn([m, n], dtype=dtype) - x = make_non_contiguous(x) - hidden_size = x.size(-1) - weight = torch.randn(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - - out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon) - ref_out = self._forward_native(x, weight, variance_epsilon) - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - ref_x = x.clone() - residual = torch.randn([m, hidden_size], dtype=dtype) - ref_residual = residual.clone() - - torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( - x, residual, weight, variance_epsilon - ) - - ref_x, ref_residual = self._forward_native( - ref_x, weight, variance_epsilon, ref_residual - ) - - torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) - torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) - - @parametrize( - l=[1, 2], - m=[4096, 1024], - n=[4096, 4109], - dtype=[torch.float16, torch.bfloat16], - ) - def test_norm_3d(self, l, m, n, dtype): - - x = torch.randn([l, m, n], dtype=dtype) - x = make_non_contiguous(x) - hidden_size = x.size(-1) - weight = torch.randn(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - - out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon) - ref_out = self._forward_native(x, weight, variance_epsilon) - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - ref_x = x.clone() - residual = torch.randn([l, m, hidden_size], dtype=dtype) - ref_residual = residual.clone() - - torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( - x, residual, weight, variance_epsilon - ) - - ref_x, ref_residual = self._forward_native( - ref_x, weight, variance_epsilon, ref_residual - ) - - torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) - torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) - - @parametrize( - m=[4096, 1024], - n=[4096, 4109], - dtype=[torch.float16, torch.bfloat16], - ) - def test_l2norm(self, m, n, dtype): - - x = torch.randn([m, n], dtype=dtype) - hidden_size = x.size(-1) - fake_ones_weight = torch.ones(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - - out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon) - ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon) - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - @parametrize( - m=[4096, 1024], - n=[4096, 4109], - dtype=[torch.float16, torch.bfloat16], - ) - def test_gemma_rmsnorm(self, m, n, dtype): - - x = torch.randn([m, n], dtype=dtype) - x = make_non_contiguous(x) - hidden_size = x.size(-1) - weight = torch.randn(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - - out = torch.ops.sgl_kernel.gemma_rmsnorm_cpu(x, weight, variance_epsilon) - ref_out = self._gemma_rmsnorm_native(x, weight, variance_epsilon) - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - ref_x = x.clone() - residual = torch.randn([m, hidden_size], dtype=dtype) - ref_residual = residual.clone() - - torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu( - x, residual, weight, variance_epsilon - ) - - ref_x, ref_residual = self._gemma_rmsnorm_native( - ref_x, weight, variance_epsilon, ref_residual - ) - - torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) - torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) - - @parametrize( - m=[4096, 1024], - n=[4096, 4109], - dtype=[torch.float16, torch.bfloat16], - ) - def test_gemma3_rmsnorm(self, m, n, dtype): - x_list = [ - torch.randn([m, n], dtype=dtype), - torch.randn([1, m, 2, n], dtype=dtype), - ] - for x in x_list: - x = make_non_contiguous(x) - hidden_size = x.size(-1) - weight = torch.randn(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - out = torch.ops.sgl_kernel.gemma3_rmsnorm_cpu(x, weight, variance_epsilon) - ref_out = self._gemma3_rmsnorm_native(x, weight, variance_epsilon) - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - def _gemma4_rmsnorm_native( - self, - x: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float = 1e-6, - scale_shift: float = 0.0, - with_scale: bool = True, - ): - output = self._norm(x.float(), variance_epsilon) - if with_scale: - output = output * (weight.float() + scale_shift) - return output.type_as(x) - - @parametrize( - m=[4096, 1024], - n=[4096, 4109], - dtype=[torch.float16, torch.bfloat16], - ) - def test_gemma4_rmsnorm(self, m, n, dtype): - for scale_shift, with_scale in [ - (0.0, True), - (1.0, True), - (0.0, False), - (1.0, False), - ]: - x_list = [ - torch.randn([m, n], dtype=dtype), - torch.randn([4, m, n], dtype=dtype), - ] - # Add non-block-contiguous 3D input - base = torch.randn([4, 2 * m, n], dtype=dtype) - x_list.append(base[:, :m, :]) - - for x in x_list: - x = make_non_contiguous(x) - hidden_size = x.size(-1) - weight = torch.randn(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - - out = torch.ops.sgl_kernel.gemma4_rmsnorm_cpu( - x, weight, variance_epsilon, scale_shift, with_scale - ) - ref_out = self._gemma4_rmsnorm_native( - x, weight, variance_epsilon, scale_shift, with_scale - ) - - atol = rtol = precision[ref_out.dtype] - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - -class TestFusedRMSNormGated(CustomTestCase): - M = [4096, 1024] - N = [4096, 4096 + 13] - dtype = [torch.float16, torch.bfloat16] - - def _forward_native( - self, - hidden_states: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float = 1e-6, - gate: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - # Norm before gate - hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) - hidden_states = weight * hidden_states.to(input_dtype) - hidden_states = hidden_states * torch.nn.functional.silu(gate.to(torch.float32)) - - return hidden_states.to(input_dtype) - - def _norm_test(self, m, n, dtype): - - x = torch.randn([m, n], dtype=dtype) - x = make_non_contiguous(x) - batch_size = x.size(0) - hidden_size = x.size(-1) - weight = torch.randn(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - gate = torch.randn([batch_size, hidden_size], dtype=dtype) - - out = torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu( - x, weight, gate, variance_epsilon - ) - ref_out = self._forward_native(x, weight, variance_epsilon, gate) - - atol = rtol = precision[ref_out.dtype] * 2 - torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) - - def test_norm(self): - for params in itertools.product(self.M, self.N, self.dtype): - with self.subTest(m=params[0], n=params[1], dtype=params[2]): - self._norm_test(*params) - - -class TestLayerNorm(CustomTestCase): - - def _forward_native( - self, - x: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, - residual: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - orig_dtype = x.dtype - x = x.to(torch.float32) - if residual is not None: - x = x + residual.to(torch.float32) - residual = x.to(orig_dtype) - - variance, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0) - x = (x - mean) * torch.rsqrt(variance + variance_epsilon) - x = x * weight.to(torch.float32) - if bias is not None: - x = x + bias.to(torch.float32) - x = x.to(orig_dtype) - return x if residual is None else (x, residual) - - @parametrize( - m=[4096, 1024], - n=[4096, 4109], - dtype=[torch.float16, torch.bfloat16], - ) - def test_norm_input_2d(self, m: int, n: int, dtype: torch.dtype) -> None: - x = torch.randn([m, n], dtype=dtype) - x = make_non_contiguous(x) - hidden_size = x.size(-1) - weight = torch.randn(hidden_size, dtype=dtype) - bias = torch.randn(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - - ln_out = torch.ops.sgl_kernel.layernorm_cpu(x, weight, None, variance_epsilon) - ref_ln_out = self._forward_native(x, weight, variance_epsilon) - - atol = rtol = precision[ref_ln_out.dtype] - torch.testing.assert_close(ln_out, ref_ln_out, atol=atol, rtol=rtol) - - ln_out = torch.ops.sgl_kernel.layernorm_cpu(x, weight, bias, variance_epsilon) - ref_ln_out = self._forward_native( - x, weight, variance_epsilon, residual=None, bias=bias - ) - torch.testing.assert_close(ln_out, ref_ln_out, atol=atol, rtol=rtol) - - residual = torch.randn([m, hidden_size], dtype=dtype) - ref_residual = residual.clone() - - add_ln_out = torch.ops.sgl_kernel.fused_add_layernorm_cpu( - x, residual, weight, None, variance_epsilon - ) - ref_add_ln_out, ref_residual = self._forward_native( - x, weight, variance_epsilon, residual=ref_residual - ) - - torch.testing.assert_close(add_ln_out, ref_add_ln_out, atol=atol, rtol=rtol) - torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) - - residual = torch.randn([m, hidden_size], dtype=dtype) - ref_residual = residual.clone() - - add_ln_out = torch.ops.sgl_kernel.fused_add_layernorm_cpu( - x, residual, weight, bias, variance_epsilon - ) - ref_add_ln_out, ref_residual = self._forward_native( - x, weight, variance_epsilon, residual=ref_residual, bias=bias - ) - - torch.testing.assert_close(add_ln_out, ref_add_ln_out, atol=atol, rtol=rtol) - torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) - - @parametrize( - l=[4096, 1024], - m=[1, 4], - n=[4096, 4109, 2304], - dtype=[torch.float16, torch.bfloat16], - ) - def test_norm_input_3d(self, l: int, m: int, n: int, dtype: torch.dtype) -> None: - x = torch.randn([l, m, n], dtype=dtype) - x = make_non_contiguous(x) - hidden_size = x.size(-1) - weight = torch.randn(hidden_size, dtype=dtype) - bias = torch.randn(hidden_size, dtype=dtype) - variance_epsilon = 1e-6 - - ln_out = torch.ops.sgl_kernel.layernorm_cpu(x, weight, None, variance_epsilon) - ref_ln_out = self._forward_native(x, weight, variance_epsilon) - - atol = rtol = precision[ref_ln_out.dtype] - torch.testing.assert_close(ln_out, ref_ln_out, atol=atol, rtol=rtol) - - ln_out = torch.ops.sgl_kernel.layernorm_cpu(x, weight, bias, variance_epsilon) - ref_ln_out = self._forward_native( - x, weight, variance_epsilon, residual=None, bias=bias - ) - torch.testing.assert_close(ln_out, ref_ln_out, atol=atol, rtol=rtol) - - residual = torch.randn([l, m, hidden_size], dtype=dtype) - ref_residual = residual.clone() - - add_ln_out = torch.ops.sgl_kernel.fused_add_layernorm_cpu( - x, residual, weight, None, variance_epsilon - ) - ref_add_ln_out, ref_residual = self._forward_native( - x, weight, variance_epsilon, ref_residual - ) - - torch.testing.assert_close(add_ln_out, ref_add_ln_out, atol=atol, rtol=rtol) - torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) - - residual = torch.randn([l, m, hidden_size], dtype=dtype) - ref_residual = residual.clone() - - add_ln_out = torch.ops.sgl_kernel.fused_add_layernorm_cpu( - x, residual, weight, bias, variance_epsilon - ) - ref_add_ln_out, ref_residual = self._forward_native( - x, weight, variance_epsilon, residual=ref_residual, bias=bias - ) - - torch.testing.assert_close(add_ln_out, ref_add_ln_out, atol=atol, rtol=rtol) - torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_qkv_proj_with_rope.py b/test/registered/cpu/test_qkv_proj_with_rope.py deleted file mode 100644 index 2fa291860f5d..000000000000 --- a/test/registered/cpu/test_qkv_proj_with_rope.py +++ /dev/null @@ -1,443 +0,0 @@ -import unittest - -import torch -from utils import ( - convert_weight, - native_w8a8_per_token_matmul, - per_token_quant_int8, - precision, -) - -from sglang.srt.layers.quantization.fp8_utils import input_to_float8 -from sglang.srt.layers.rotary_embedding.utils import apply_rotary_emb -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed -qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope -qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight -torch.manual_seed(1234) -# constants -kv_lora_rank = 512 -qk_head_dim = 192 -qk_nope_head_dim = 128 -qk_rope_head_dim = 64 -rotary_dim = qk_rope_head_dim -num_heads = 22 -q_lora_rank = 1536 -hidden_size = 7168 -B = 1 -eps = 1e-6 - - -def layernorm(x, weight, variance_epsilon=1e-6, residual=None): - orig_dtype = x.dtype - x = x.to(torch.float32) - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + variance_epsilon) - return (x * weight).to(orig_dtype) - - -def rotary_emb(q_pe, k_pe, pos, cos_sin_cache): - orig_dtype = q_pe.dtype - q_pe = q_pe.float() - k_pe = k_pe.float() - cos_sin_cache = cos_sin_cache.float() - - query_rot = q_pe[..., :rotary_dim] - key_rot = k_pe[..., :rotary_dim] - cos_sin = cos_sin_cache[pos] - cos, sin = cos_sin.chunk(2, dim=-1) - query_rot = apply_rotary_emb(query_rot, cos, sin, False) - key_rot = apply_rotary_emb(key_rot, cos, sin, False) - return query_rot.to(orig_dtype), key_rot.to(orig_dtype) - - -def native_torch( - q_input, - hidden_states, - q_a_proj_weight, - norm_weight1, - q_b_proj_weight, - w_kc, - kv_a_proj_weight, - norm_weight2, - pos, - cos_sin_cache, -): - - q = torch.matmul(hidden_states, q_a_proj_weight.t()) - q = layernorm(q, norm_weight1) - q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim) - - q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) - q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc) - - q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1) - latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t()) - v_input = latent_cache[..., :kv_lora_rank] - v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1) - k_input = latent_cache.unsqueeze(1) - k_input[..., :kv_lora_rank] = v_input - k_pe = k_input[..., kv_lora_rank:] - - q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache) - q_input[..., kv_lora_rank:] = q_pe - k_input[..., kv_lora_rank:] = k_pe - - return q_input, k_input, v_input - - -def native_torch_int8( - q_input, - hidden_states, - w1_q, - w1_s, - norm_weight1, - w2_q, - w2_s, - w_kc, - w3_q, - w3_s, - norm_weight2, - pos, - cos_sin_cache, -): - - a_q, a_s = per_token_quant_int8(hidden_states) - q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16) - q = layernorm(q, norm_weight1) - - a_q, a_s = per_token_quant_int8(q) - q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view( - -1, num_heads, qk_head_dim - ) - - q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) - q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc) - - q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1) - a_q, a_s = per_token_quant_int8(hidden_states) - latent_cache = native_w8a8_per_token_matmul( - a_q, w3_q, a_s, w3_s, None, torch.bfloat16 - ) - v_input = latent_cache[..., :kv_lora_rank] - v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1) - k_input = latent_cache.unsqueeze(1) - k_input[..., :kv_lora_rank] = v_input - k_pe = k_input[..., kv_lora_rank:] - - q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache) - q_input[..., kv_lora_rank:] = q_pe - k_input[..., kv_lora_rank:] = k_pe - - return q_input, k_input, v_input - - -class TestQKVProjWithROPE(CustomTestCase): - def test_bf16_qkv_proj_with_rope(self): - dtype = torch.bfloat16 - hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size - q_input = torch.empty( - B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype - ) - q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 - norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) - q_b_proj_weight = ( - torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 - ) - w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 - kv_a_proj_weight = ( - torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 - ) - fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0) - norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) - pos = torch.randint(10, 100, (B,)) - cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) - q_ref, k_ref, v_ref = native_torch( - q_input, - hidden_states, - q_a_proj_weight, - norm_weight1, - q_b_proj_weight, - w_kc.transpose(1, 2), - kv_a_proj_weight, - norm_weight2, - pos, - cos_sin_cache, - ) - qa_packed = convert_weight_packed(q_a_proj_weight) - qb_packed = convert_weight_packed(q_b_proj_weight) - kva_packed = convert_weight_packed(kv_a_proj_weight) - wkc_packed = convert_weight_packed(w_kc) - fused_weight_packed = convert_weight_packed(fused_weight) - - q_out, k_out, v_out = qkv_proj_with_rope( - hidden_states, - qa_packed, - qb_packed, - kva_packed, - wkc_packed, - norm_weight1, - norm_weight2, - pos, - cos_sin_cache, - eps, - False, - False, - None, - None, - None, - None, - True, - None, - ) - fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight( - hidden_states, - fused_weight_packed, - qb_packed, - wkc_packed, - norm_weight1, - norm_weight2, - pos, - cos_sin_cache, - eps, - False, - False, - None, - None, - None, - True, - None, - q_lora_rank, - kv_lora_rank, - qk_rope_head_dim, - ) - atol = rtol = precision[q_ref.dtype] - torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol) - torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol) - torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol) - torch.testing.assert_close(fused_q_out, q_out) - torch.testing.assert_close(fused_k_out, k_out) - torch.testing.assert_close(fused_v_out, v_out) - - def test_int8_qkv_proj_with_rope(self): - dtype = torch.bfloat16 - hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size - q_input = torch.empty( - B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype - ) - q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 - norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) - q_b_proj_weight = ( - torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 - ) - w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 - kv_a_proj_weight = ( - torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 - ) - norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) - pos = torch.randint(10, 100, (B,)) - cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) - - w1_q, w1_s = per_token_quant_int8(q_a_proj_weight) - w2_q, w2_s = per_token_quant_int8(q_b_proj_weight) - w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight) - q_ref, k_ref, v_ref = native_torch_int8( - q_input, - hidden_states, - w1_q, - w1_s, - norm_weight1, - w2_q, - w2_s, - w_kc.transpose(1, 2), - w3_q, - w3_s, - norm_weight2, - pos, - cos_sin_cache, - ) - w1_q_packed = convert_weight_packed(w1_q) - w2_q_packed = convert_weight_packed(w2_q) - w3_q_packed = convert_weight_packed(w3_q) - wkc_packed = convert_weight_packed(w_kc) - q_out, k_out, v_out = qkv_proj_with_rope( - hidden_states, - w1_q_packed, - w2_q_packed, - w3_q_packed, - wkc_packed, - norm_weight1, - norm_weight2, - pos, - cos_sin_cache, - eps, - True, - False, - w1_s, - w2_s, - w3_s, - None, - True, - None, - ) - fused_weight = torch.cat([w1_q, w3_q], dim=0) - fused_weight_s = torch.cat([w1_s, w3_s], dim=0) - w_fused_q_packed = convert_weight_packed(fused_weight) - fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight( - hidden_states, - w_fused_q_packed, - w2_q_packed, - wkc_packed, - norm_weight1, - norm_weight2, - pos, - cos_sin_cache, - eps, - True, - False, - fused_weight_s, - w2_s, - None, - True, - None, - q_lora_rank, - kv_lora_rank, - qk_rope_head_dim, - ) - atol = rtol = precision[q_ref.dtype] - torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol) - torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol) - torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol) - torch.testing.assert_close(fused_q_out, q_out) - torch.testing.assert_close(fused_k_out, k_out) - torch.testing.assert_close(fused_v_out, v_out) - - def test_fp8_qkv_proj_with_rope(self): - dtype = torch.bfloat16 - hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size - q_input = torch.empty( - B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype - ) - q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 - norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) - q_b_proj_weight = ( - torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 - ) - w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 - w_kc_q, w_kc_s = input_to_float8(w_kc) - kv_a_proj_weight = ( - torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 - ) - norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) - pos = torch.randint(10, 100, (B,)) - cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) - - scale_block_size_N = 128 - scale_block_size_K = 128 - fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = ( - convert_weight( - q_a_proj_weight, - [scale_block_size_N, scale_block_size_K], - torch.bfloat16, - ) - ) - fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = ( - convert_weight( - q_b_proj_weight, - [scale_block_size_N, scale_block_size_K], - torch.bfloat16, - ) - ) - ( - fp8_kv_a_proj_with_mqa_weight, - kv_a_proj_with_mqa_weight_scale_inv, - kv_a_proj_with_mqa_weight_dq, - ) = convert_weight( - kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16 - ) - w_kc_dq = w_kc_q.to(torch.bfloat16) * w_kc_s - q_ref, k_ref, v_ref = native_torch( - q_input, - hidden_states, - q_a_proj_weight_dq, - norm_weight1, - q_b_proj_weight_dq, - w_kc_dq.transpose(1, 2), - kv_a_proj_with_mqa_weight_dq, - norm_weight2, - pos, - cos_sin_cache, - ) - fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight) - fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight) - fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed( - fp8_kv_a_proj_with_mqa_weight - ) - w_kc_q = convert_weight_packed(w_kc_q) - q_out, k_out, v_out = qkv_proj_with_rope( - hidden_states, - fp8_q_a_proj_weight_packed, - fp8_q_b_proj_weight_packed, - fp8_kv_a_proj_with_mqa_weight_packed, - w_kc_q, - norm_weight1, - norm_weight2, - pos, - cos_sin_cache, - eps, - False, - True, - q_a_proj_weight_scale_inv.float(), - q_b_proj_weight_scale_inv.float(), - kv_a_proj_with_mqa_weight_scale_inv.float(), - w_kc_s, - True, - [scale_block_size_N, scale_block_size_K], - ) - - fused_weight = torch.cat( - [fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0 - ) - fused_weight_s = torch.cat( - [q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0 - ) - fused_weight_packed = convert_weight_packed(fused_weight) - fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight( - hidden_states, - fused_weight_packed, - fp8_q_b_proj_weight_packed, - w_kc_q, - norm_weight1, - norm_weight2, - pos, - cos_sin_cache, - eps, - False, - True, - fused_weight_s.float(), - q_b_proj_weight_scale_inv.float(), - w_kc_s, - True, - [scale_block_size_N, scale_block_size_K], - q_lora_rank, - kv_lora_rank, - qk_rope_head_dim, - ) - atol = rtol = precision[q_ref.dtype] - # Due to the change in multiplication order, the error is amplified. - # In the model, with fewer layers, this doesn't cause issues, but in - # tests with more layers, we need to enlarge the tolerance to pass the tests. - torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1) - torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol) - torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol) - torch.testing.assert_close(fused_q_out, q_out) - torch.testing.assert_close(fused_k_out, k_out) - torch.testing.assert_close(fused_v_out, v_out) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_qwen3.py b/test/registered/cpu/test_qwen3.py deleted file mode 100644 index e8edfeb649ae..000000000000 --- a/test/registered/cpu/test_qwen3.py +++ /dev/null @@ -1,152 +0,0 @@ -import unittest - -import torch -from utils import precision - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -def fix_query_key_value_ordering_reshape_cat( - mixed_qkvz, mixed_ba, num_k_heads, num_v_heads, attn_tp_size, head_k_dim, head_v_dim -): - new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - num_k_heads // attn_tp_size, - ( - head_k_dim - + head_k_dim - + (head_v_dim + head_v_dim) * num_v_heads // num_k_heads - ), - ) - new_tensor_shape_ba = mixed_ba.size()[:-1] + ( - num_k_heads // attn_tp_size, - 2 * num_v_heads // num_k_heads, - ) - - mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - - split_arg_list_qkvz = [ - head_k_dim, - head_k_dim, - (num_v_heads // num_k_heads * head_v_dim), - (num_v_heads // num_k_heads * head_v_dim), - ] - split_arg_list_ba = [ - num_v_heads // num_k_heads, - num_v_heads // num_k_heads, - ] - # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] - # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] - query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) - b, a = torch.split(mixed_ba, split_arg_list_ba, dim=2) - - # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - value = value.reshape(value.size(0), -1, head_v_dim) - z = z.reshape(z.size(0), -1, head_v_dim) - b = b.reshape(b.size(0), num_v_heads // attn_tp_size) - a = a.reshape(a.size(0), num_v_heads // attn_tp_size) - query, key, value = map(lambda x: x.reshape(x.shape[0], -1), (query, key, value)) - mixed_qkv = torch.cat((query, key, value), dim=-1) - - return mixed_qkv, z, b, a - - -def fix_query_key_value_ordering_reshape_cat_contiguous( - mixed_qkvz: torch.Tensor, - mixed_ba: torch.Tensor, - key_dim: int, - value_dim: int, - num_v_heads: int, - head_v_dim: int, - attn_tp_size: int, -): - """ - Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. - """ - k_tp = key_dim // attn_tp_size - v_tp = value_dim // attn_tp_size - nv_tp = num_v_heads // attn_tp_size - - # Directly split, no head group reshape - query, key, value, z = mixed_qkvz.split([k_tp, k_tp, v_tp, v_tp], dim=-1) - b, a = mixed_ba.split([nv_tp, nv_tp], dim=-1) - - # value / z reshape to (seq, num_v_heads/tp, head_v_dim) - value = value.reshape(value.size(0), -1, head_v_dim) - z = z.reshape(z.size(0), -1, head_v_dim) - query, key, value = map(lambda x: x.reshape(x.shape[0], -1), (query, key, value)) - mixed_qkv = torch.cat((query, key, value), dim=-1) - return mixed_qkv, z, b, a - - -class TestQwen3(CustomTestCase): - def test_fused_qkvzba_split_reshape_cat(self): - mixed_qkvz = torch.rand(1024, 12288, dtype=torch.bfloat16) - mixed_ba = torch.rand(1024, 64, dtype=torch.bfloat16) - head_k_dim = 128 - head_v_dim = 128 - num_v_heads = 32 - num_k_heads = 16 - attn_tp_size = 1 - mixed_qkv_ref, z_ref, b_ref, a_ref = fix_query_key_value_ordering_reshape_cat( - mixed_qkvz, - mixed_ba, - num_k_heads, - num_v_heads, - attn_tp_size, - head_k_dim, - head_v_dim, - ) - num_heads_qk = num_k_heads // attn_tp_size - num_heads_v = num_v_heads // attn_tp_size - mixed_qkv, z, b, a = torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu( - mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_k_dim, head_v_dim - ) - atol = rtol = precision[mixed_qkv.dtype] - torch.testing.assert_close(mixed_qkv, mixed_qkv_ref, atol=atol, rtol=rtol) - torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) - torch.testing.assert_close(b, b_ref, atol=atol, rtol=rtol) - torch.testing.assert_close(a, a_ref, atol=atol, rtol=rtol) - - def test_fused_qkvzba_split_reshape_cat_contiguous(self): - mixed_qkvz = torch.rand(1, 12288, dtype=torch.bfloat16) - mixed_ba = torch.rand(1, 64, dtype=torch.bfloat16) - head_k_dim = 128 - head_v_dim = 128 - num_v_heads = 32 - num_k_heads = 16 - attn_tp_size = 1 - key_dim = head_k_dim * num_k_heads - value_dim = head_v_dim * num_v_heads - mixed_qkv_ref, z_ref, b_ref, a_ref = ( - fix_query_key_value_ordering_reshape_cat_contiguous( - mixed_qkvz, - mixed_ba, - key_dim, - value_dim, - num_v_heads, - head_v_dim, - attn_tp_size, - ) - ) - num_heads_qk = num_k_heads // attn_tp_size - num_heads_v = num_v_heads // attn_tp_size - mixed_qkv, z, b, a = ( - torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_contiguous_cpu( - mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_k_dim, head_v_dim - ) - ) - atol = rtol = precision[mixed_qkv.dtype] - torch.testing.assert_close(mixed_qkv, mixed_qkv_ref, atol=atol, rtol=rtol) - torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) - torch.testing.assert_close(b, b_ref, atol=atol, rtol=rtol) - torch.testing.assert_close(a, a_ref, atol=atol, rtol=rtol) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_rope.py b/test/registered/cpu/test_rope.py deleted file mode 100644 index 7246214d2dac..000000000000 --- a/test/registered/cpu/test_rope.py +++ /dev/null @@ -1,285 +0,0 @@ -import unittest - -import torch -from utils import precision - -from sglang.srt.layers.rotary_embedding import ( - MRotaryEmbedding, - RotaryEmbedding, -) -from sglang.srt.layers.rotary_embedding.rope_variant import ( - DeepseekScalingRotaryEmbedding, - apply_rotary_pos_emb_native, -) -from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class TestROPE(CustomTestCase): - def test_mrope(self): - torch.manual_seed(100) - head_size = 128 - seq_len = 512 - num_heads = 16 - num_kv_heads = 1 - rotary_dim = 128 - max_pos = 262144 - base = 5000000 - is_neox_style = True - dtype = torch.bfloat16 - mrope_section = [24, 20, 20] - mrope_interleaved = True - positions_mrope = torch.randint(0, max_pos, (3, seq_len)) - positions_text = torch.randint(0, max_pos, (seq_len,)) - set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) - - test_config = [ - # (dtype, is_neox_stype, mrope_interleaved, positions, mrope_section) - (torch.bfloat16, False, True, positions_mrope, mrope_section), - (torch.bfloat16, False, False, positions_mrope, mrope_section), - (torch.bfloat16, False, False, positions_text, None), - (torch.bfloat16, True, True, positions_mrope, mrope_section), - (torch.bfloat16, True, False, positions_mrope, mrope_section), - (torch.bfloat16, True, False, positions_text, None), - ] - for ( - dtype, - is_neox_style, - mrope_interleaved, - positions, - mrope_section, - ) in test_config: - rope = MRotaryEmbedding( - head_size, - rotary_dim, - max_pos, - base, - is_neox_style, - dtype, - mrope_section, - mrope_interleaved, - ) - enable_autocast = True - - with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast): - q = torch.randn(seq_len, num_heads * head_size, dtype=dtype) - q_clone = q.clone() - k = torch.randn(seq_len, num_kv_heads * head_size, dtype=dtype) - k_clone = k.clone() - - # ref kernel - q_ref, k_ref = rope.forward_native( - query=q, - key=k, - positions=positions, - ) - # fused rope kernel - q_sgl, k_sgl = torch.ops.sgl_kernel.multimodal_rotary_embedding_cpu( - positions, - q_clone, - k_clone, - rope.head_size, - rope.cos_sin_cache, - rope.mrope_section, - rope.mrope_interleaved, - is_neox_style, - ) - atol = rtol = precision[q_ref.dtype] - torch.testing.assert_close(q_ref, q_sgl, atol=atol, rtol=rtol) - torch.testing.assert_close(k_ref, k_sgl, atol=atol, rtol=rtol) - - def test_deepseek_v2_rope(self): - num_head = 16 - seq_len = 1024 - q_head_dim = 192 - qk_nope_head_dim = 128 - qk_rope_head_dim = 64 - max_pos = 256 - k_dim = 576 - rotary_dim = 64 - is_neox_style = False - set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) - - # Create cos_sin_cache - freqs = torch.rand(max_pos, qk_rope_head_dim // 2) - cos = freqs.cos() * 0.7 - sin = freqs.sin() * 0.7 - cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16) - positions = torch.randint(0, max_pos, (seq_len,)) - - rope = DeepseekScalingRotaryEmbedding( - qk_rope_head_dim, - rotary_dim, - max_pos, - 16, # not used since cos_sin_cache is provided - is_neox_style, - 1.0, - torch.bfloat16, - device="cpu", - ) - rope.register_buffer("cos_sin_cache", cos_sin_cache) - - for dtype in [torch.bfloat16]: - enable_autocast = True - - with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast): - q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype) - q_clone = q.clone() - k = torch.randn(seq_len, 1, k_dim, dtype=dtype) - k_clone = k.clone() - _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) - _, q_pe_clone = q_clone.split( - [qk_nope_head_dim, qk_rope_head_dim], dim=-1 - ) - k_pe = k[:, :, k_dim - qk_rope_head_dim :] - k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :] - - # ref kernel - q_pe, k_pe = rope.forward_native( - query=q_pe, - key=k_pe, - positions=positions, - ) - - # fused rope kernel - q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu( - positions, - q_pe_clone, - k_pe_clone, - rope.head_size, - cos_sin_cache, - False, - ) - - atol = rtol = precision[q_pe.dtype] - torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol) - torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol) - torch.testing.assert_close(k_pe, k_pe_clone) - - def test_origin_rope(self): - def single_test( - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - dims: int, - is_neox_style: bool, - dtype: torch.dtype, - device: str, - batch_size: int, - seq_len: int, - num_q_heads: int, - num_kv_heads: int, - ): - set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) - torch.manual_seed(100) - rope_ref = RotaryEmbedding( - head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype, - ).to(device) - pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) - query = torch.randn( - batch_size * seq_len, - num_q_heads * head_size, - dtype=dtype, - device=device, - ) - key = torch.randn( - batch_size * seq_len, - num_kv_heads * head_size, - dtype=dtype, - device=device, - ) - if dims == 4: - query = query.view(batch_size, seq_len, num_q_heads, head_size) - key = key.view(batch_size, seq_len, num_kv_heads, head_size) - query_ref, key_ref = query.clone(), key.clone() - query_cpu, key_cpu = query.clone(), key.clone() - - query_ref_out, key_ref_out = rope_ref.forward_native( - pos_ids, query_ref, key_ref - ) - query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu( - pos_ids, - query_cpu, - key_cpu, - rope_ref.head_size, - rope_ref.cos_sin_cache.to(query.dtype), - rope_ref.is_neox_style, - ) - torch.testing.assert_close( - query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2 - ) - torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2) - - test_config = [ - (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1), - (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8), - (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2), - (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8), - (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4), - (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2), - ] - - for ( - head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype, - device, - batch_size, - seq_len, - num_q_heads, - num_kv_heads, - ) in test_config: - for dim in [2, 4]: - single_test( - head_size, - rotary_dim, - max_position_embeddings, - base, - dim, - is_neox_style, - dtype, - device, - batch_size, - seq_len, - num_q_heads, - num_kv_heads, - ) - - def test_apply_rotary_pos_emb(self): - num_tokens = 1024 - num_heads = 8 - head_size = 72 - qkv = torch.randn(num_tokens, num_heads * head_size * 3).to(torch.bfloat16) - query, key, _ = qkv.split( - [num_heads * head_size, num_heads * head_size, num_heads * head_size], - dim=-1, - ) - query = query.view(num_tokens, num_heads, head_size) - key = key.view(num_tokens, num_heads, head_size) - for sincos_dtype in [torch.float32, torch.bfloat16]: - cos = torch.rand(num_tokens, head_size).to(sincos_dtype) - sin = torch.rand(num_tokens, head_size).to(sincos_dtype) - q_out_ref, k_out_ref = apply_rotary_pos_emb_native(query, key, cos, sin) - q_out_sgl, k_out_sgl = torch.ops.sgl_kernel.apply_rotary_pos_emb_cpu( - query, key, cos, sin - ) - torch.testing.assert_close(q_out_ref, q_out_sgl, atol=1e-2, rtol=1e-2) - torch.testing.assert_close(k_out_ref, k_out_sgl, atol=1e-2, rtol=1e-2) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_server_args_backend.py b/test/registered/cpu/test_server_args_backend.py deleted file mode 100644 index 22448a7f471d..000000000000 --- a/test/registered/cpu/test_server_args_backend.py +++ /dev/null @@ -1,38 +0,0 @@ -import unittest -from unittest.mock import patch - -from sglang.srt.server_args import ServerArgs -from sglang.test.ci.ci_register import register_cpu_ci - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - - -class TestServerArgsCPUBackend(unittest.TestCase): - def _make_server_args(self, attention_backend=None): - server_args = ServerArgs.__new__(ServerArgs) - server_args.device = "cpu" - server_args.attention_backend = attention_backend - server_args.sampling_backend = None - return server_args - - @patch("sglang.srt.server_args.is_host_cpu_arm64", return_value=True) - def test_arm_cpu_defaults_to_torch_native(self, _mock_is_arm64): - server_args = self._make_server_args() - - ServerArgs._handle_cpu_backends(server_args) - - self.assertEqual(server_args.attention_backend, "torch_native") - self.assertEqual(server_args.sampling_backend, "pytorch") - - @patch("sglang.srt.server_args.is_host_cpu_arm64", return_value=False) - def test_x86_cpu_defaults_to_intel_amx(self, _mock_is_arm64): - server_args = self._make_server_args() - - ServerArgs._handle_cpu_backends(server_args) - - self.assertEqual(server_args.attention_backend, "intel_amx") - self.assertEqual(server_args.sampling_backend, "pytorch") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_shared_expert.py b/test/registered/cpu/test_shared_expert.py deleted file mode 100644 index 18b3cef78acc..000000000000 --- a/test/registered/cpu/test_shared_expert.py +++ /dev/null @@ -1,235 +0,0 @@ -import itertools -import math -import unittest - -import torch -from utils import ( - BLOCK_K, - BLOCK_N, - factor_for_scale, - fp8_max, - fp8_min, - per_token_quant_int8, - precision, - scaled_weight, - torch_naive_moe, - torch_w8a8_per_column_moe, -) - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -class TestSharedExpert(CustomTestCase): - M = [2, 121] - N = [32, 32 * 4] - K = [32, 32 * 2] - routed_scaling_factor = [16] - apply_scaling_factor = [True, False] - - M_fp8 = [2, 12] - N_fp8 = [512] - K_fp8 = [256] - - def _bf16_shared_expert(self, m, n, k, routed_scaling_factor, apply_scaling_factor): - dtype = torch.bfloat16 - - hidden_states = torch.randn(m, k, dtype=dtype) / k - w1 = torch.randn(2 * n, k, dtype=dtype) - w2 = torch.randn(k, n, dtype=dtype) - fused_output = ( - torch.randn(m, k, dtype=dtype) / k if apply_scaling_factor else None - ) - routed_scaling_factor = routed_scaling_factor if apply_scaling_factor else None - - # fused moe mutates content in hs - hidden_states2 = hidden_states.clone() - - # bfloat16 - ref = torch_naive_moe( - hidden_states, - w1, - w2, - fused_output, - routed_scaling_factor, - output_dtype=dtype, - ) - out = torch.ops.sgl_kernel.shared_expert_cpu( - hidden_states2, - w1, - w2, - fused_output, - routed_scaling_factor, - True, - False, - False, - None, - None, - None, - False, - ) - - atol = rtol = precision[ref.dtype] - torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) - - def test_bf16_shared_expert(self): - for params in itertools.product( - self.M, - self.N, - self.K, - self.routed_scaling_factor, - self.apply_scaling_factor, - ): - with self.subTest( - m=params[0], - n=params[1], - k=params[2], - routed_scaling_factor=params[3], - apply_scaling_factor=params[4], - ): - self._bf16_shared_expert(*params) - - def _int8_shared_expert(self, m, n, k, routed_scaling_factor, apply_scaling_factor): - dtype = torch.bfloat16 - - hidden_states = torch.randn(m, k, dtype=dtype) / k - w1 = torch.randn(2 * n, k, dtype=dtype) - w2 = torch.randn(k, n, dtype=dtype) - fused_output = ( - torch.randn(m, k, dtype=dtype) / k if apply_scaling_factor else None - ) - routed_scaling_factor = routed_scaling_factor if apply_scaling_factor else None - - # fused moe mutates content in hs - hidden_states2 = hidden_states.clone() - - w1_q, w1_s = per_token_quant_int8(w1) - w2_q, w2_s = per_token_quant_int8(w2) - ref = torch_w8a8_per_column_moe( - hidden_states, - w1_q, - w2_q, - w1_s, - w2_s, - fused_output, - routed_scaling_factor, - ) - out = torch.ops.sgl_kernel.shared_expert_cpu( - hidden_states2, - w1_q, - w2_q, - fused_output, - routed_scaling_factor, - True, - True, - False, - w1_s, - w2_s, - None, - False, - ) - - atol = rtol = precision[ref.dtype] - torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) - - def test_int8_shared_expert(self): - for params in itertools.product( - self.M, - self.N, - self.K, - self.routed_scaling_factor, - self.apply_scaling_factor, - ): - with self.subTest( - m=params[0], - n=params[1], - k=params[2], - routed_scaling_factor=params[3], - apply_scaling_factor=params[4], - ): - self._int8_shared_expert(*params) - - def _fp8_shared_expert(self, m, n, k, routed_scaling_factor, apply_scaling_factor): - dtype = torch.bfloat16 - - hidden_states = torch.randn(m, k, dtype=dtype) / math.sqrt(k) - - w1_fp32 = torch.randn(1, 2 * n, k) - w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - w2_fp32 = torch.randn(1, k, n) - w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - w1s = torch.randn(1, 2 * n // BLOCK_N, k // BLOCK_K) * factor_for_scale - w2s = torch.randn(1, k // BLOCK_N, n // BLOCK_K) * factor_for_scale - - w1_scaled = scaled_weight(w1, w1s).view(2 * n, k) - w2_scaled = scaled_weight(w2, w2s).view(k, n) - - # change back to 2D - w1, w2 = w1.squeeze(0), w2.squeeze(0) - w1s, w2s = w1s.squeeze(0), w2s.squeeze(0) - w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0) - - fused_output = ( - torch.randn(m, k, dtype=dtype) / math.sqrt(k) - if apply_scaling_factor - else None - ) - routed_scaling_factor = routed_scaling_factor if apply_scaling_factor else None - hidden_states2 = hidden_states.clone() - - # ref with bfloat16 - ref = torch_naive_moe( - hidden_states, - w1_scaled, - w2_scaled, - fused_output, - routed_scaling_factor, - output_dtype=dtype, - ) - - w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K] - w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N] - out = torch.ops.sgl_kernel.shared_expert_cpu( - hidden_states2, - w1, - w2, - fused_output, - routed_scaling_factor, - True, - False, - True, - w1s, - w2s, - [BLOCK_N, BLOCK_K], - True, - ) - - atol = rtol = precision[ref.dtype] - torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) - - def test_fp8_shared_expert(self): - for params in itertools.product( - self.M_fp8, - self.N_fp8, - self.K_fp8, - self.routed_scaling_factor, - self.apply_scaling_factor, - ): - with self.subTest( - m=params[0], - n=params[1], - k=params[2], - routed_scaling_factor=params[3], - apply_scaling_factor=params[4], - ): - self._fp8_shared_expert(*params) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/test_topk.py b/test/registered/cpu/test_topk.py deleted file mode 100644 index 2d6d2bbbeb0f..000000000000 --- a/test/registered/cpu/test_topk.py +++ /dev/null @@ -1,223 +0,0 @@ -import unittest - -import torch - -from sglang.srt.layers.moe.topk import ( - biased_grouped_topk_impl as native_biased_grouped_topk, -) -from sglang.srt.layers.moe.topk import fused_topk_torch_native as native_fused_topk -from sglang.srt.layers.moe.topk import grouped_topk_gpu as native_grouped_topk -from sglang.srt.models.llama4 import Llama4MoE -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase - -register_cpu_ci(est_time=10, suite="stage-b-test-cpu") - -torch.manual_seed(1234) - - -# This is used by the Deepseek-V2 model -class TestGroupedTopK(CustomTestCase): - def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype): - torch.manual_seed(1234) - - # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating - hidden_states = torch.randn(M, 100, dtype=dtype) - gating_output = torch.randn(M, E, dtype=dtype) * 2 * M - - ref_topk_weights, ref_topk_ids = native_grouped_topk( - hidden_states.float(), - gating_output.float(), - topk, - renormalize, - G, - topk_group, - ) - - # fused version - topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu( - hidden_states, - gating_output, - topk, - renormalize, - G, - topk_group, - 0, - None, - None, - ) - - res = torch.zeros(M, E, dtype=torch.float) - ref = torch.zeros(M, E, dtype=torch.float) - res.scatter_(1, topk_ids.long(), topk_weights) - ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) - torch.testing.assert_close(res, ref) - - def test_grouped_topk(self): - for renormalize in [True, False]: - self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16) - self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16) - self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16) - self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16) - self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16) - self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16) - self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16) - - -# DeepSeek V2/V3/R1 uses biased_grouped_top -class TestBiasedGroupedTopK(CustomTestCase): - def _run_single_test( - self, - M, - E, - G, - topk, - topk_group, - renormalize, - gating_dtype, - bias_dtype, - routed_scaling_factor, - ): - torch.manual_seed(1024) - - # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating - hidden_states = torch.randn(M, 100, dtype=torch.bfloat16) - gating_output = torch.randn(M, E, dtype=gating_dtype) * 2 * M - correction_bias = torch.randn(E, dtype=bias_dtype) - - ref_topk_weights, ref_topk_ids = native_biased_grouped_topk( - hidden_states.float(), - gating_output.float(), - correction_bias.float(), - topk, - renormalize, - G, - topk_group, - ) - ref_topk_weights = ( - ref_topk_weights * routed_scaling_factor - if routed_scaling_factor is not None - else ref_topk_weights - ) - # fused version - topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu( - hidden_states, - gating_output, - correction_bias, - topk, - renormalize, - G, - topk_group, - 0, - routed_scaling_factor, - None, - ) - - res = torch.zeros(M, E, dtype=torch.float) - ref = torch.zeros(M, E, dtype=torch.float) - res.scatter_(1, topk_ids.long(), topk_weights) - ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) - torch.testing.assert_close(res, ref) - - def test_biased_grouped_topk(self): - for renormalize in [False]: - for bias_dtype in [torch.float32, torch.bfloat16]: - for gating_dtype in [torch.float32, torch.bfloat16]: - for routed_scaling_factor in [None, 1.125]: - for E_num in [128, 192, 256, 384]: - self._run_single_test( - 34, - E_num, - 8, - 8, - 2, - renormalize, - gating_dtype, - bias_dtype, - routed_scaling_factor, - ) - - -class TestTopK(CustomTestCase): - def _run_single_test(self, M, E, topk, renormalize, dtype): - torch.manual_seed(1998) - - # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating - hidden_states = torch.randn(M, 100, dtype=dtype) - gating_output = torch.randn(M, E, dtype=dtype) * 2 * M - - ref_topk_weights, ref_topk_ids = native_fused_topk( - hidden_states.float(), - gating_output.float(), - topk, - renormalize, - ) - - # fused version - topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu( - hidden_states, gating_output, topk, renormalize - ) - - res = torch.zeros(M, E, dtype=torch.float) - ref = torch.zeros(M, E, dtype=torch.float) - res.scatter_(1, topk_ids.long(), topk_weights) - ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) - torch.testing.assert_close(res, ref) - - def test_topk(self): - for renormalize in [True, False]: - self._run_single_test(123, 8, 2, renormalize, torch.bfloat16) - self._run_single_test(123, 16, 3, renormalize, torch.bfloat16) - self._run_single_test(123, 32, 3, renormalize, torch.bfloat16) - self._run_single_test(123, 32, 3, renormalize, torch.bfloat16) - self._run_single_test(123, 64, 6, renormalize, torch.bfloat16) - self._run_single_test(123, 256, 4, renormalize, torch.bfloat16) - self._run_single_test(123, 160, 6, renormalize, torch.bfloat16) - - -class TestCustomTopK(CustomTestCase): - def _run_single_test( - self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f - ): - torch.manual_seed(16) - - # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating - hidden_states = torch.randn(M, 100, dtype=dtype) - gating_output = torch.randn(M, E, dtype=dtype) * 2 * M - - ref_topk_weights, ref_topk_ids = native_custom_f( - hidden_states.float(), - gating_output.float(), - topk, - renormalize, - ) - - # fused version - topk_weights, topk_ids = fused_custom_f( - hidden_states, gating_output, topk, renormalize - ) - - res = torch.zeros(M, E, dtype=torch.float) - ref = torch.zeros(M, E, dtype=torch.float) - res.scatter_(1, topk_ids.long(), topk_weights) - ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) - torch.testing.assert_close(res, ref) - - def test_custom_topk(self): - test_custom_functions = [ - (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu) - ] - for native_custom_f, fused_custom_f in test_custom_functions: - self._run_single_test( - 123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f - ) - self._run_single_test( - 123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f - ) - self._run_single_test( - 123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/cpu/utils.py b/test/registered/cpu/utils.py deleted file mode 100644 index 57e7e74c103c..000000000000 --- a/test/registered/cpu/utils.py +++ /dev/null @@ -1,440 +0,0 @@ -import itertools -import math - -import torch -import torch.nn.functional as F - -precision = { - torch.bfloat16: 1e-2, - torch.float16: 1e-3, - torch.float32: 1e-5, -} - - -BLOCK_N, BLOCK_K = 64, 128 -factor_for_scale = 1e-3 -fp8_max, fp8_min = 400, -400 - - -def parametrize(**params): - def decorator(func): - def wrapper(self): - for combo in itertools.product(*params.values()): - kwargs = dict(zip(params.keys(), combo)) - with self.subTest(**kwargs): - func(self, **kwargs) - - return wrapper - - return decorator - - -def SiluAndMul(x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - return F.silu(x[..., :d]) * x[..., d:] - - -def GeluAndMul(x: torch.Tensor, approximate="tanh") -> torch.Tensor: - d = x.shape[-1] // 2 - return F.gelu(x[..., :d], approximate=approximate) * x[..., d:] - - -def per_token_quant_int8(x): - x = x.float() - absmax = x.abs().max(dim=-1).values - absmax = absmax.clamp_min(1e-10).unsqueeze(-1) - scale_x = absmax / 127 - x_q = x.mul(127 / absmax) - x_q = torch.round(x_q).to(torch.int8) - - return x_q, scale_x - - -def convert_weight(weight, scale_block_size, A_dtype): - N, K = weight.size() - fp8_max = 448.0 - scale_block_size_N, scale_block_size_K = scale_block_size # (128, 128) - - pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N - pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K - - if pad_N > 0 or pad_K > 0: - weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N)) - - weight_blocks = weight.view( - math.ceil(N / scale_block_size_N), - scale_block_size_N, - math.ceil(K / scale_block_size_K), - scale_block_size_K, - ) # (8, 128, 8, 128) - weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous() # (8, 8, 128, 128) - - # Step 2: compute per-block max abs values → scale - abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True) # (8, 8, 1, 1) - scales = abs_max / fp8_max - scales = torch.where( - scales == 0, torch.ones_like(scales), scales - ) # avoid division by zero - - q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn) - q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous() - - if pad_N > 0 or pad_K > 0: - q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K) - q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous() - else: - q_fp8_reshape = q_fp8_reshape.view(N, K) - - dq_weight = q_fp8.float() * scales - dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous() # (8, 128, 8, 128) - - if pad_N > 0 or pad_K > 0: - w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype) - w_dq = w_dq[:N, :K].contiguous() - else: - w_dq = dq_weight.view(N, K).to(A_dtype) - - scales = scales.view( - math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K) - ) - - return q_fp8_reshape, scales, w_dq - - -def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16): - """Matrix multiplication function that supports per-token input quantization and per-column weight quantization""" - A = A.to(torch.float32) - B = B.to(torch.float32) - - assert A.shape[-1] == B.shape[-1], "Dimension mismatch" - assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" - - # Reshape input - M = A.numel() // A.shape[-1] - B = B.t() # Transpose weight matrix - N, K = B.shape - origin_C_shape = A.shape[:-1] + (K,) - A = A.reshape(M, N) - - # As is per-token [M, 1], Bs is per-column [1, K] - C = torch.matmul(A, B) # [M, K] - C = As * C * Bs.view(1, -1) # Broadcast per-column scale - - if bias is not None: - C.add_(bias.view(1, -1)) - - return C.reshape(origin_C_shape).to(output_dtype) - - -def torch_naive_moe(a, w1, w2, b, routed_scaling_factor, output_dtype=torch.bfloat16): - - a = a.to(torch.float32) - w1 = w1.to(torch.float32) - w2 = w2.to(torch.float32) - b = b.to(torch.float32) if b is not None else None - - ic1 = torch.matmul(a, w1.transpose(0, 1)) - ic2 = SiluAndMul(ic1) - ic3 = torch.matmul(ic2, w2.transpose(0, 1)) - - out = ic3 if b is None else ic3 + b * routed_scaling_factor - - return out.to(output_dtype) - - -def torch_w8a8_per_column_moe( - a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor, output_dtype=torch.bfloat16 -): - - a = a.to(torch.float32) - b = b.to(torch.float32) if b is not None else None - - # Perform per-token quantization - a_q, a_s = per_token_quant_int8(a) - - ic1 = native_w8a8_per_token_matmul( - a_q, w1_q, a_s, w1_s, bias=None, output_dtype=torch.float32 - ) - ic2 = SiluAndMul(ic1) - - a1_q, a1_s = per_token_quant_int8(ic2) - ic3 = native_w8a8_per_token_matmul( - a1_q, w2_q, a1_s, w2_s, bias=None, output_dtype=torch.float32 - ) - - out = ic3 if b is None else ic3 + b * routed_scaling_factor - - return out.to(output_dtype) - - -def scaled_weight(weight, scales): - E, N, K = weight.shape - pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N - pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K - - if pad_N > 0 or pad_K > 0: - weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N)) - - weight_block = ( - weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K) - .permute(0, 1, 3, 2, 4) - .float() - .contiguous() - ) - - weight_scaled = ( - ( - weight_block - * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1) - ) - .permute(0, 1, 3, 2, 4) - .contiguous() - ) - if pad_N > 0 or pad_K > 0: - weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K) - weight_scaled = weight_scaled[..., :N, :K].contiguous() - else: - weight_scaled = weight_scaled.view(E, N, K) - return weight_scaled - - -def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - - if renormalize: - topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) - - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = SiluAndMul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( - 0, 1 - ) - return ( - out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) - ).sum(dim=1) - - -def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk): - """This function performs fused moe with per-column int8 quantization using native torch.""" - - B, D = a.shape - # Perform per-token quantization - a_q, a_s = per_token_quant_int8(a) - # Repeat tokens to match topk - a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - # Also repeat the scale - a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] - - out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device) - - # Calculate routing - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - # Process each expert - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - # First MLP layer: note that a_s is now per-token - inter_out = native_w8a8_per_token_matmul( - a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - bias=None, - output_dtype=torch.float32, - ) - # Activation function - act_out = SiluAndMul(inter_out) - # Quantize activation output with per-token - act_out_q, act_out_s = per_token_quant_int8(act_out) - # Second MLP layer - out[mask] = native_w8a8_per_token_matmul( - act_out_q, - w2[i], - act_out_s, - w2_s[i], - bias=None, - output_dtype=torch.float32, - ) - # Apply routing weights and sum - return ( - (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)) - .sum(dim=1) - .to(a.dtype) - ) - - -def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D).float() - out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device) - - # Calculate routing - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - ic0 = torch.matmul(a[mask], w1[i].transpose(0, 1)) - ic1 = SiluAndMul(ic0) - out[mask] = torch.matmul(ic1, w2[i].transpose(0, 1)) - - return ( - (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)) - .sum(dim=1) - .to(a.dtype) - ) - - -def make_non_contiguous(x: torch.Tensor) -> torch.Tensor: - """ - Make a tensor non-contiguous by slicing it via last dimension. - """ - last_dim = x.shape[-1] - return x[..., : last_dim // 2] if x.is_contiguous() else x - - -def awq_reverse_reorder_int_tensor(int_tensor, bits: int): - assert bits == 4 - - int_tensor = int_tensor.T.contiguous() - compress_ratio = 32 // bits - assert int_tensor.shape[-1] % compress_ratio == 0 - - order_map = [0, 2, 4, 6, 1, 3, 5, 7] - order_tensor = torch.tensor( - order_map, dtype=torch.int32, device=int_tensor.device - ).reshape(1, -1) - order_tensor = order_tensor.repeat(int_tensor.shape[1] // compress_ratio, 1) - order_tensor = order_tensor + torch.arange( - 0, - int_tensor.shape[1], - compress_ratio, - dtype=torch.int32, - device=int_tensor.device, - ).reshape(-1, 1) - order_tensor = order_tensor.reshape(-1) - - reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor] - reverse_order_tensor = reverse_order_tensor[order_tensor] - int_tensor = int_tensor[:, reverse_order_tensor] - return int_tensor - - -def unpack_and_dequant_awq( - awq_qweight: torch.Tensor, - awq_qzeros: torch.Tensor, - awq_scales: torch.Tensor, - bits: int, - group_size: int, -): - """ - Args: - awq_qweight (`torch.LongTensor`): - Expected shape: (in_features, out_features // (32 // bits)) - awq_qzeros (`torch.LongTensor`): - Expected shape: (in_features // group_size, out_features // (32 // bits)) - awq_scales (`torch.LongTensor`): - Expected shape: (in_features // group_size, out_features) - - Returns: - fp16_weight (`torch.LongTensor`): - With shape (in_features, out_features). - zeros (`torch.LongTensor`): - With shape (in_features // group_size, out_features). - """ - assert bits == 4 - - qzeros = awq_qzeros - qweight = awq_qweight - qweight = qweight.T.contiguous() - - scales = awq_scales - scales = scales.reshape(-1, 1, scales.shape[-1]) - - infeatures = awq_qweight.shape[0] - - wf = torch.tensor( - list(range(0, 32, bits)), dtype=torch.int32, device=qzeros.device - ).unsqueeze(0) - zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0)).to( - torch.int16 if bits == 8 else torch.int8 - ) - - torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) - - zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2]) - - weight = torch.bitwise_right_shift( - torch.unsqueeze(qweight, 1), wf.unsqueeze(-1) - ).to(torch.int16 if bits == 8 else torch.int8) - torch.bitwise_and(weight, (2**bits) - 1, out=weight) - weight = weight.reshape(-1, group_size, weight.shape[2]) - - weight = weight.view(-1, weight.shape[-1]) - zeros = zeros.view(-1, zeros.shape[-1]) - - zeros = zeros.T.contiguous() - zeros = awq_reverse_reorder_int_tensor(zeros, bits) - weight = awq_reverse_reorder_int_tensor(weight, bits) - - # Dequantize weights. - scales = awq_scales - zeros = zeros.contiguous() - scale_zeros = zeros * scales - - g_idx = torch.tensor( - [i // group_size for i in range(infeatures)], dtype=torch.int32 - ) - scale_mat = scales[g_idx] - scale_zeros_mat = scale_zeros[g_idx].to(torch.bfloat16) - - qdq_weight_T = weight * scale_mat - scale_zeros_mat.to(torch.bfloat16) - - fp16_weight = qdq_weight_T.T - - return fp16_weight, zeros - - -def unpack_4bit_to_32bit_signed(qweight, qzeros): - # Unpack 4-bit values and interpret them as signed integers - unpacked_weights = torch.zeros( - (qweight.shape[0] * 8, qweight.shape[1]), - dtype=torch.int8, - device=qweight.device, - requires_grad=False, - ) - unpacked_zeros = torch.zeros( - (qzeros.shape[0], qzeros.shape[1] * 8), - dtype=torch.int8, - device=qzeros.device, - requires_grad=False, - ) - - for row in range(unpacked_weights.shape[0]): - i = row % 8 - unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF - - for col in range(unpacked_zeros.shape[1]): - i = col % 8 - unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF - - return unpacked_weights, unpacked_zeros + 1 - - -def unpack_and_dequant_gptq(qweight, qzeros, scales): - unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros) - group_size = unpacked_qweight.shape[0] // scales.shape[0] - scales = scales.repeat_interleave(group_size, dim=0) - unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0) - unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales - - return unpacked_qweight.T diff --git a/test/run_suite.py b/test/run_suite.py index e38546c66f3e..c09b1fae8fc9 100644 --- a/test/run_suite.py +++ b/test/run_suite.py @@ -23,7 +23,7 @@ # Per-commit test suites (run on every PR) PER_COMMIT_SUITES = { - HWBackend.CPU: ["stage-a-test-cpu", "stage-b-test-cpu"], + HWBackend.CPU: ["stage-a-test-cpu"], HWBackend.AMD: [ "stage-a-test-1-gpu-small-amd", "stage-b-test-1-gpu-small-amd", @@ -238,9 +238,7 @@ def run_a_suite(args): for f in glob.glob( os.path.join(script_dir, "registered", "**", "*.py"), recursive=True ) - if not f.endswith("/conftest.py") - and not f.endswith("/__init__.py") - and not f.endswith("/cpu/utils.py") + if not f.endswith("/conftest.py") and not f.endswith("/__init__.py") ] # JIT kernel tests and benchmarks (live alongside kernel source)