From 5d121f8ba75d3b083dd005dcc48864fd791a3b53 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 24 Oct 2025 14:48:05 -0400 Subject: [PATCH 001/150] Implement basic test --- test/srt/layers/moe/test_moe_runners.py | 52 +++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 test/srt/layers/moe/test_moe_runners.py diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py new file mode 100644 index 000000000000..5a5119560465 --- /dev/null +++ b/test/srt/layers/moe/test_moe_runners.py @@ -0,0 +1,52 @@ +# Minimal integration smoke test for launching a server with the MoE Triton runner. + +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestMoERunnerTriton(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--moe-runner-backend", + "triton", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_moe_runner_smoke(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=1, + num_threads=1, + ) + + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreaterEqual(metrics["score"], 0.0) + + +if __name__ == "__main__": + unittest.main() From 317c4f60f60c445f232da7d7bfc017c7e12657ef Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 24 Oct 2025 15:01:55 -0400 Subject: [PATCH 002/150] Shard the model across two gpus --- test/srt/layers/moe/test_moe_runners.py | 
2 ++ 1 file changed, 2 insertions(+) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index 5a5119560465..80f175031fdc 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -27,6 +27,8 @@ def setUpClass(cls): "--trust-remote-code", "--moe-runner-backend", "triton", + "--tp", + "2", ], ) From da4285faa5f8366a072c943c17b4883c7154a0f1 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 24 Oct 2025 16:04:08 -0400 Subject: [PATCH 003/150] Add multi-test support --- test/srt/layers/moe/test_moe_runners.py | 79 +++++++++++++++++-------- 1 file changed, 53 insertions(+), 26 deletions(-) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index 80f175031fdc..55444872c6d1 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -1,5 +1,3 @@ -# Minimal integration smoke test for launching a server with the MoE Triton runner. 
- import unittest from types import SimpleNamespace @@ -15,39 +13,68 @@ class TestMoERunnerTriton(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ + BASE_URL = DEFAULT_URL_FOR_TEST + TIMEOUT = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + DEFAULT_MODEL = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT + DEFAULT_EVAL_KWARGS = { + "eval_name": "mmlu", + "num_examples": 5, + "num_threads": 1, + } + + CONFIGS = { + "tp2_torch_native": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "other_args": [ "--trust-remote-code", "--moe-runner-backend", "triton", "--tp", "2", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + "--max-total-tokens", + "2048", ], - ) + }, + } + + def _run_config(self, config: dict) -> None: + model = config.get("model", self.DEFAULT_MODEL) + other_args = config.get("other_args", []) + eval_kwargs = self.DEFAULT_EVAL_KWARGS | config.get("eval_kwargs", {}) - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_moe_runner_smoke(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=1, - num_threads=1, + process = popen_launch_server( + model, + self.BASE_URL, + timeout=self.TIMEOUT, + other_args=other_args, ) + try: + args = SimpleNamespace( + base_url=self.BASE_URL, + model=model, + **eval_kwargs, + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreaterEqual(metrics["score"], 0.0) + finally: + kill_process_tree(process.pid) + + +def _make_test(config_name: str, config: dict): + def test(self): + self._run_config(config) + + test.__name__ = f"test_{config_name}" + return test + - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreaterEqual(metrics["score"], 0.0) +for _name, 
_config in TestMoERunnerTriton.CONFIGS.items(): + setattr(TestMoERunnerTriton, f"test_{_name}", _make_test(_name, _config)) if __name__ == "__main__": From 4008c4261aab8cc7f623067c22c5a609f1b97817 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 24 Oct 2025 16:31:21 -0400 Subject: [PATCH 004/150] Add comprehensive test configs --- test/srt/layers/moe/test_moe_runners.py | 142 ++++++++++++++++++++---- 1 file changed, 122 insertions(+), 20 deletions(-) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index 55444872c6d1..45107e67ccf0 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -1,9 +1,13 @@ +import os import unittest from types import SimpleNamespace from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( + DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -23,7 +27,7 @@ class TestMoERunnerTriton(CustomTestCase): } CONFIGS = { - "tp2_torch_native": { + "moe_runner_triton_unquant_standard": { "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", @@ -31,38 +35,136 @@ class TestMoERunnerTriton(CustomTestCase): "triton", "--tp", "2", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", "--max-total-tokens", "2048", ], }, + "moe_runner_triton_kernel_unquant_standard": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "triton_kernel", + "--tp", + "2", + "--max-total-tokens", + "2048", + ], + }, + "moe_runner_deep_gemm_awq_quantization": { + "model": DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "deep_gemm", + 
"--quantization", + "awq", + "--tp", + "2", + "--max-total-tokens", + "2048", + ], + "eval_kwargs": {"num_examples": 3}, + }, + "moe_runner_flashinfer_trtllm_mxfp4_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "flashinfer_trtllm", + "--quantization", + "mxfp4", + "--tp", + "2", + "--max-total-tokens", + "2048", + ], + "eval_kwargs": {"num_examples": 2}, + }, + "moe_runner_flashinfer_mxfp4_fp8_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--quantization", + "mxfp4", + "--tp", + "2", + "--max-total-tokens", + "2048", + ], + "eval_kwargs": {"num_examples": 2}, + }, + "moe_runner_flashinfer_cutedsl_fp8_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "flashinfer_cutedsl", + "--quantization", + "fp8", + "--tp", + "2", + "--max-total-tokens", + "2048", + ], + "eval_kwargs": {"num_examples": 2}, + }, + "moe_runner_cutlass_w8a8_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "cutlass", + "--quantization", + "w8a8", + "--tp", + "2", + "--max-total-tokens", + "2048", + ], + "eval_kwargs": {"num_examples": 2}, + }, } def _run_config(self, config: dict) -> None: model = config.get("model", self.DEFAULT_MODEL) other_args = config.get("other_args", []) eval_kwargs = self.DEFAULT_EVAL_KWARGS | config.get("eval_kwargs", {}) + env_overrides = config.get("env", {}) or {} - process = popen_launch_server( - model, - self.BASE_URL, - timeout=self.TIMEOUT, - other_args=other_args, - ) + saved_env = {} try: - args = SimpleNamespace( - base_url=self.BASE_URL, - model=model, - **eval_kwargs, - ) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreaterEqual(metrics["score"], 0.0) + for 
key, value in env_overrides.items(): + saved_env[key] = os.environ.get(key) + os.environ[key] = value + + process = None + try: + process = popen_launch_server( + model, + self.BASE_URL, + timeout=self.TIMEOUT, + other_args=other_args, + ) + args = SimpleNamespace( + base_url=self.BASE_URL, + model=model, + **eval_kwargs, + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreaterEqual(metrics["score"], 0.0) + finally: + if process is not None: + kill_process_tree(process.pid) finally: - kill_process_tree(process.pid) + for key, previous in saved_env.items(): + if previous is None: + os.environ.pop(key, None) + else: + os.environ[key] = previous def _make_test(config_name: str, config: dict): From 30a39866d53caf244a5f13abaccfd58bdee21756 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 24 Oct 2025 16:58:29 -0400 Subject: [PATCH 005/150] Add comprehensive test configs --- test/srt/layers/moe/test_moe_runners.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index 45107e67ccf0..d62061c1aa98 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -37,6 +37,8 @@ class TestMoERunnerTriton(CustomTestCase): "2", "--max-total-tokens", "2048", + "--mem-fraction-static", + "0.95", ], }, "moe_runner_triton_kernel_unquant_standard": { @@ -49,6 +51,8 @@ class TestMoERunnerTriton(CustomTestCase): "2", "--max-total-tokens", "2048", + "--mem-fraction-static", + "0.95", ], }, "moe_runner_deep_gemm_awq_quantization": { @@ -63,6 +67,8 @@ class TestMoERunnerTriton(CustomTestCase): "2", "--max-total-tokens", "2048", + "--mem-fraction-static", + "0.95", ], "eval_kwargs": {"num_examples": 3}, }, @@ -78,6 +84,8 @@ class TestMoERunnerTriton(CustomTestCase): "2", "--max-total-tokens", "2048", + "--mem-fraction-static", + "0.95", ], "eval_kwargs": {"num_examples": 2}, }, @@ -93,6 +101,8 @@ class 
TestMoERunnerTriton(CustomTestCase): "2", "--max-total-tokens", "2048", + "--mem-fraction-static", + "0.95", ], "eval_kwargs": {"num_examples": 2}, }, @@ -108,6 +118,8 @@ class TestMoERunnerTriton(CustomTestCase): "2", "--max-total-tokens", "2048", + "--mem-fraction-static", + "0.95", ], "eval_kwargs": {"num_examples": 2}, }, @@ -123,6 +135,8 @@ class TestMoERunnerTriton(CustomTestCase): "2", "--max-total-tokens", "2048", + "--mem-fraction-static", + "0.95", ], "eval_kwargs": {"num_examples": 2}, }, From 121fcae110d00599a65ae2f37b550fc4fa168d58 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 24 Oct 2025 17:07:07 -0400 Subject: [PATCH 006/150] Rename moe test file --- .../moe/{test_moe_runners.py => test_moe.py} | 44 ++++++++----------- 1 file changed, 19 insertions(+), 25 deletions(-) rename test/srt/layers/moe/{test_moe_runners.py => test_moe.py} (85%) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe.py similarity index 85% rename from test/srt/layers/moe/test_moe_runners.py rename to test/srt/layers/moe/test_moe.py index d62061c1aa98..c073049d36f3 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe.py @@ -146,34 +146,28 @@ def _run_config(self, config: dict) -> None: model = config.get("model", self.DEFAULT_MODEL) other_args = config.get("other_args", []) eval_kwargs = self.DEFAULT_EVAL_KWARGS | config.get("eval_kwargs", {}) - env_overrides = config.get("env", {}) or {} - saved_env = {} - try: - for key, value in env_overrides.items(): - saved_env[key] = os.environ.get(key) - os.environ[key] = value + env_overrides = config.get("env", {}) + saved_env = {k: os.environ.get(k) for k in env_overrides} + os.environ.update(env_overrides) - process = None - try: - process = popen_launch_server( - model, - self.BASE_URL, - timeout=self.TIMEOUT, - other_args=other_args, - ) - args = SimpleNamespace( - base_url=self.BASE_URL, - model=model, - **eval_kwargs, - ) - metrics = run_eval(args) - 
print(f"{metrics=}") - self.assertGreaterEqual(metrics["score"], 0.0) - finally: - if process is not None: - kill_process_tree(process.pid) + process = popen_launch_server( + model, + self.BASE_URL, + timeout=self.TIMEOUT, + other_args=other_args, + ) + try: + args = SimpleNamespace( + base_url=self.BASE_URL, + model=model, + **eval_kwargs, + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreaterEqual(metrics["score"], 0.0) finally: + kill_process_tree(process.pid) for key, previous in saved_env.items(): if previous is None: os.environ.pop(key, None) From d57c7a0e709464aaac494ae5e30c45a7ffdd12eb Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 26 Oct 2025 17:01:25 -0400 Subject: [PATCH 007/150] Add spec decoding cases --- test/srt/layers/moe/test_moe.py | 54 +++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/test/srt/layers/moe/test_moe.py b/test/srt/layers/moe/test_moe.py index c073049d36f3..5e7d5b9511e8 100644 --- a/test/srt/layers/moe/test_moe.py +++ b/test/srt/layers/moe/test_moe.py @@ -28,6 +28,21 @@ class TestMoERunnerTriton(CustomTestCase): CONFIGS = { "moe_runner_triton_unquant_standard": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "triton_kernel", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + ], + }, + # Speculative decoding (EAGLE) on unquantized small MoE with Triton + "moe_runner_triton_unquant_spec_eagle": { "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", @@ -39,7 +54,20 @@ class TestMoERunnerTriton(CustomTestCase): "2048", "--mem-fraction-static", "0.95", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", ], + "eval_kwargs": {"num_examples": 3}, + "env": { + "SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0", + 
"SGLANG_ENABLE_JIT_DEEPGEMM": "0", + }, }, "moe_runner_triton_kernel_unquant_standard": { "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, @@ -89,6 +117,32 @@ class TestMoERunnerTriton(CustomTestCase): ], "eval_kwargs": {"num_examples": 2}, }, + # Speculative decoding (NGRAM) with FlashInfer MXFP4 backend (differs from main path) + "moe_runner_flashinfer_mxfp4_quantization_spec_ngram": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--quantization", + "mxfp4", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--speculative-algorithm", + "NGRAM", + "--speculative-num-draft-tokens", + "8", + ], + "eval_kwargs": {"num_examples": 2}, + "env": { + "SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0", + "SGLANG_ENABLE_JIT_DEEPGEMM": "0", + }, + }, "moe_runner_flashinfer_mxfp4_fp8_quantization": { "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, "other_args": [ From 54c0d013ba3207be37c142e29e4165d6a5786ec0 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 27 Oct 2025 21:07:30 -0400 Subject: [PATCH 008/150] Simplify code --- test/srt/layers/moe/test_moe.py | 48 +++++++-------------------------- 1 file changed, 9 insertions(+), 39 deletions(-) diff --git a/test/srt/layers/moe/test_moe.py b/test/srt/layers/moe/test_moe.py index 5e7d5b9511e8..e1aa55b780d0 100644 --- a/test/srt/layers/moe/test_moe.py +++ b/test/srt/layers/moe/test_moe.py @@ -1,4 +1,3 @@ -import os import unittest from types import SimpleNamespace @@ -16,10 +15,9 @@ ) -class TestMoERunnerTriton(CustomTestCase): +class TestMoERunner(CustomTestCase): BASE_URL = DEFAULT_URL_FOR_TEST TIMEOUT = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - DEFAULT_MODEL = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT DEFAULT_EVAL_KWARGS = { "eval_name": "mmlu", "num_examples": 5, @@ -63,11 +61,6 @@ class TestMoERunnerTriton(CustomTestCase): "--speculative-num-draft-tokens", "2", ], - "eval_kwargs": 
{"num_examples": 3}, - "env": { - "SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0", - "SGLANG_ENABLE_JIT_DEEPGEMM": "0", - }, }, "moe_runner_triton_kernel_unquant_standard": { "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, @@ -98,7 +91,6 @@ class TestMoERunnerTriton(CustomTestCase): "--mem-fraction-static", "0.95", ], - "eval_kwargs": {"num_examples": 3}, }, "moe_runner_flashinfer_trtllm_mxfp4_quantization": { "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, @@ -115,7 +107,6 @@ class TestMoERunnerTriton(CustomTestCase): "--mem-fraction-static", "0.95", ], - "eval_kwargs": {"num_examples": 2}, }, # Speculative decoding (NGRAM) with FlashInfer MXFP4 backend (differs from main path) "moe_runner_flashinfer_mxfp4_quantization_spec_ngram": { @@ -137,11 +128,6 @@ class TestMoERunnerTriton(CustomTestCase): "--speculative-num-draft-tokens", "8", ], - "eval_kwargs": {"num_examples": 2}, - "env": { - "SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0", - "SGLANG_ENABLE_JIT_DEEPGEMM": "0", - }, }, "moe_runner_flashinfer_mxfp4_fp8_quantization": { "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, @@ -158,7 +144,6 @@ class TestMoERunnerTriton(CustomTestCase): "--mem-fraction-static", "0.95", ], - "eval_kwargs": {"num_examples": 2}, }, "moe_runner_flashinfer_cutedsl_fp8_quantization": { "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, @@ -175,7 +160,6 @@ class TestMoERunnerTriton(CustomTestCase): "--mem-fraction-static", "0.95", ], - "eval_kwargs": {"num_examples": 2}, }, "moe_runner_cutlass_w8a8_quantization": { "model": DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, @@ -192,18 +176,13 @@ class TestMoERunnerTriton(CustomTestCase): "--mem-fraction-static", "0.95", ], - "eval_kwargs": {"num_examples": 2}, }, } def _run_config(self, config: dict) -> None: - model = config.get("model", self.DEFAULT_MODEL) + model = config["model"] other_args = config.get("other_args", []) - eval_kwargs = self.DEFAULT_EVAL_KWARGS | config.get("eval_kwargs", {}) - - env_overrides = config.get("env", {}) - saved_env = 
{k: os.environ.get(k) for k in env_overrides} - os.environ.update(env_overrides) + eval_kwargs = self.DEFAULT_EVAL_KWARGS process = popen_launch_server( model, @@ -222,23 +201,14 @@ def _run_config(self, config: dict) -> None: self.assertGreaterEqual(metrics["score"], 0.0) finally: kill_process_tree(process.pid) - for key, previous in saved_env.items(): - if previous is None: - os.environ.pop(key, None) - else: - os.environ[key] = previous - - -def _make_test(config_name: str, config: dict): - def test(self): - self._run_config(config) - - test.__name__ = f"test_{config_name}" - return test -for _name, _config in TestMoERunnerTriton.CONFIGS.items(): - setattr(TestMoERunnerTriton, f"test_{_name}", _make_test(_name, _config)) +for _name, _cfg in TestMoERunner.CONFIGS.items(): + setattr( + TestMoERunner, + f"test_{_name}", + (lambda self, cfg=_cfg: self._run_config(cfg)), + ) if __name__ == "__main__": From 9dd39ce51f2a11e3eac7135ae7d0a012f0bd9289 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 27 Oct 2025 21:24:59 -0400 Subject: [PATCH 009/150] Fix config issues --- test/srt/layers/moe/test_moe.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/srt/layers/moe/test_moe.py b/test/srt/layers/moe/test_moe.py index e1aa55b780d0..645548790d1e 100644 --- a/test/srt/layers/moe/test_moe.py +++ b/test/srt/layers/moe/test_moe.py @@ -37,6 +37,10 @@ class TestMoERunner(CustomTestCase): "2048", "--mem-fraction-static", "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", ], }, # Speculative decoding (EAGLE) on unquantized small MoE with Triton @@ -60,6 +64,10 @@ class TestMoERunner(CustomTestCase): "1", "--speculative-num-draft-tokens", "2", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", ], }, "moe_runner_triton_kernel_unquant_standard": { @@ -74,6 +82,10 @@ class TestMoERunner(CustomTestCase): "2048", "--mem-fraction-static", "0.95", + "--attention-backend", + 
"torch_native", + "--sampling-backend", + "pytorch", ], }, "moe_runner_deep_gemm_awq_quantization": { @@ -90,6 +102,10 @@ class TestMoERunner(CustomTestCase): "2048", "--mem-fraction-static", "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", ], }, "moe_runner_flashinfer_trtllm_mxfp4_quantization": { @@ -127,6 +143,10 @@ class TestMoERunner(CustomTestCase): "NGRAM", "--speculative-num-draft-tokens", "8", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", ], }, "moe_runner_flashinfer_mxfp4_fp8_quantization": { @@ -143,6 +163,10 @@ class TestMoERunner(CustomTestCase): "2048", "--mem-fraction-static", "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", ], }, "moe_runner_flashinfer_cutedsl_fp8_quantization": { @@ -159,6 +183,10 @@ class TestMoERunner(CustomTestCase): "2048", "--mem-fraction-static", "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", ], }, "moe_runner_cutlass_w8a8_quantization": { @@ -175,6 +203,10 @@ class TestMoERunner(CustomTestCase): "2048", "--mem-fraction-static", "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", ], }, } From d092df0245965b463664f43b1003432ae4d3b69d Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 27 Oct 2025 21:33:46 -0400 Subject: [PATCH 010/150] Fix config issues --- test/srt/layers/moe/test_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/layers/moe/test_moe.py b/test/srt/layers/moe/test_moe.py index 645548790d1e..8881511f49c6 100644 --- a/test/srt/layers/moe/test_moe.py +++ b/test/srt/layers/moe/test_moe.py @@ -196,7 +196,7 @@ class TestMoERunner(CustomTestCase): "--moe-runner-backend", "cutlass", "--quantization", - "w8a8", + "w8a8_int8", "--tp", "2", "--max-total-tokens", From ac734f652081ba30cdd04bc251c658ab5f13933a Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 27 Oct 2025 22:08:32 -0400 Subject: 
[PATCH 011/150] Add default mxfp4 moe test model --- python/sglang/test/test_utils.py | 4 ++++ test/srt/layers/moe/test_moe.py | 9 +++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 97642ace3d3d..a46af25579ca 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -71,6 +71,10 @@ DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8 = "Qwen/Qwen3-1.7B-FP8" DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE = "gaunernst/DeepSeek-V2-Lite-Chat-FP8" +# MXFP4 models +# Standard MXFP4 MoE test model +DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE = "openai/gpt-oss-20b" + # W8A8 models DEFAULT_MODEL_NAME_FOR_TEST_W8A8 = "RedHatAI/Llama-3.2-3B-quantized.w8a8" DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8" diff --git a/test/srt/layers/moe/test_moe.py b/test/srt/layers/moe/test_moe.py index 8881511f49c6..114b6c75f47f 100644 --- a/test/srt/layers/moe/test_moe.py +++ b/test/srt/layers/moe/test_moe.py @@ -6,6 +6,7 @@ from sglang.test.test_utils import ( DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -109,7 +110,7 @@ class TestMoERunner(CustomTestCase): ], }, "moe_runner_flashinfer_trtllm_mxfp4_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, "other_args": [ "--trust-remote-code", "--moe-runner-backend", @@ -126,7 +127,7 @@ class TestMoERunner(CustomTestCase): }, # Speculative decoding (NGRAM) with FlashInfer MXFP4 backend (differs from main path) "moe_runner_flashinfer_mxfp4_quantization_spec_ngram": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, "other_args": [ "--trust-remote-code", "--moe-runner-backend", @@ -149,8 
+150,8 @@ class TestMoERunner(CustomTestCase): "pytorch", ], }, - "moe_runner_flashinfer_mxfp4_fp8_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "moe_runner_flashinfer_mxfp4_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, "other_args": [ "--trust-remote-code", "--moe-runner-backend", From 34baf323e2f3976358b07811dd841714a92bb30b Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 28 Oct 2025 08:51:42 -0400 Subject: [PATCH 012/150] Add configs for auto backend choosing logic --- test/srt/layers/moe/test_moe.py | 63 ++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/test/srt/layers/moe/test_moe.py b/test/srt/layers/moe/test_moe.py index 114b6c75f47f..8657742f1ec0 100644 --- a/test/srt/layers/moe/test_moe.py +++ b/test/srt/layers/moe/test_moe.py @@ -27,25 +27,6 @@ class TestMoERunner(CustomTestCase): CONFIGS = { "moe_runner_triton_unquant_standard": { - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "triton_kernel", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - # Speculative decoding (EAGLE) on unquantized small MoE with Triton - "moe_runner_triton_unquant_spec_eagle": { "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", @@ -57,14 +38,6 @@ class TestMoERunner(CustomTestCase): "2048", "--mem-fraction-static", "0.95", - "--speculative-algorithm", - "EAGLE", - "--speculative-num-steps", - "1", - "--speculative-eagle-topk", - "1", - "--speculative-num-draft-tokens", - "2", "--attention-backend", "torch_native", "--sampling-backend", @@ -210,6 +183,40 @@ class TestMoERunner(CustomTestCase): "pytorch", ], }, + "moe_runner_auto_refactored": { # 'auto' where the potential backend has been refactored + "model": 
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "other_args": [ + "--trust-remote-code", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, + "moe_runner_auto_not_refactored": { # 'auto' where the potential backend has not been refactored + "model": DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--quantization", + "w8a8_int8", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, } def _run_config(self, config: dict) -> None: @@ -231,7 +238,7 @@ def _run_config(self, config: dict) -> None: ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreaterEqual(metrics["score"], 0.0) + self.assertGreaterEqual(metrics["score"], 0.48) finally: kill_process_tree(process.pid) From f01cb32e9bbea3bcd021a99a33254966914455d0 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 28 Oct 2025 09:20:57 -0400 Subject: [PATCH 013/150] Rename file and remove unnecessary configs --- test/srt/layers/moe/test_moe_runners.py | 221 ++++++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 test/srt/layers/moe/test_moe_runners.py diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py new file mode 100644 index 000000000000..ecae12ebc5c3 --- /dev/null +++ b/test/srt/layers/moe/test_moe_runners.py @@ -0,0 +1,221 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, + DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, + DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + 
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestMoERunner(CustomTestCase): + BASE_URL = DEFAULT_URL_FOR_TEST + TIMEOUT = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + DEFAULT_EVAL_KWARGS = { + "eval_name": "mmlu", + "num_examples": 5, + "num_threads": 1, + } + + CONFIGS = { + "moe_runner_triton_unquant_standard": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "triton", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, + "moe_runner_triton_kernel_unquant_standard": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "triton_kernel", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, + "moe_runner_deep_gemm_awq_quantization": { + "model": DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "deep_gemm", + "--quantization", + "awq", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, + "moe_runner_flashinfer_trtllm_mxfp4_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "flashinfer_trtllm", + "--quantization", + "mxfp4", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + ], + }, + # Speculative decoding (NGRAM) with FlashInfer MXFP4 backend (differs from main path) + "moe_runner_flashinfer_mxfp4_quantization_spec_ngram": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, + "other_args": [ + 
"--trust-remote-code", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--quantization", + "mxfp4", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--speculative-algorithm", + "NGRAM", + "--speculative-num-draft-tokens", + "8", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, + "moe_runner_flashinfer_mxfp4_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--quantization", + "mxfp4", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, + "moe_runner_flashinfer_cutedsl_fp8_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "flashinfer_cutedsl", + "--quantization", + "fp8", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, + "moe_runner_cutlass_w8a8_quantization": { + "model": DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "cutlass", + "--quantization", + "w8a8_int8", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", + ], + }, + } + + def _run_config(self, config: dict) -> None: + model = config["model"] + other_args = config.get("other_args", []) + eval_kwargs = self.DEFAULT_EVAL_KWARGS + + process = popen_launch_server( + model, + self.BASE_URL, + timeout=self.TIMEOUT, + other_args=other_args, + ) + try: + args = SimpleNamespace( + base_url=self.BASE_URL, + model=model, + **eval_kwargs, + ) + metrics = run_eval(args) + print(f"{metrics=}") + 
self.assertGreaterEqual(metrics["score"], 0.48) + finally: + kill_process_tree(process.pid) + + +for _name, _cfg in TestMoERunner.CONFIGS.items(): + setattr( + TestMoERunner, + f"test_{_name}", + (lambda self, cfg=_cfg: self._run_config(cfg)), + ) + + +if __name__ == "__main__": + unittest.main() From 46b34eda926b62d5b15cc1b37abc1f287f59397e Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 28 Oct 2025 12:00:17 -0400 Subject: [PATCH 014/150] Simplify configs --- test/srt/layers/moe/test_moe.py | 255 ------------------------ test/srt/layers/moe/test_moe_runners.py | 67 ++++--- 2 files changed, 35 insertions(+), 287 deletions(-) delete mode 100644 test/srt/layers/moe/test_moe.py diff --git a/test/srt/layers/moe/test_moe.py b/test/srt/layers/moe/test_moe.py deleted file mode 100644 index 8657742f1ec0..000000000000 --- a/test/srt/layers/moe/test_moe.py +++ /dev/null @@ -1,255 +0,0 @@ -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import kill_process_tree -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, - DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, - DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, - DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, - DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - - -class TestMoERunner(CustomTestCase): - BASE_URL = DEFAULT_URL_FOR_TEST - TIMEOUT = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - DEFAULT_EVAL_KWARGS = { - "eval_name": "mmlu", - "num_examples": 5, - "num_threads": 1, - } - - CONFIGS = { - "moe_runner_triton_unquant_standard": { - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "triton", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - 
"moe_runner_triton_kernel_unquant_standard": { - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "triton_kernel", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - "moe_runner_deep_gemm_awq_quantization": { - "model": DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "deep_gemm", - "--quantization", - "awq", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - "moe_runner_flashinfer_trtllm_mxfp4_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "flashinfer_trtllm", - "--quantization", - "mxfp4", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - ], - }, - # Speculative decoding (NGRAM) with FlashInfer MXFP4 backend (differs from main path) - "moe_runner_flashinfer_mxfp4_quantization_spec_ngram": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "flashinfer_mxfp4", - "--quantization", - "mxfp4", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--speculative-algorithm", - "NGRAM", - "--speculative-num-draft-tokens", - "8", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - "moe_runner_flashinfer_mxfp4_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "flashinfer_mxfp4", - "--quantization", - "mxfp4", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - 
"--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - "moe_runner_flashinfer_cutedsl_fp8_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "flashinfer_cutedsl", - "--quantization", - "fp8", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - "moe_runner_cutlass_w8a8_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, - "other_args": [ - "--trust-remote-code", - "--moe-runner-backend", - "cutlass", - "--quantization", - "w8a8_int8", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - "moe_runner_auto_refactored": { # 'auto' where the potential backend has been refactored - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, - "other_args": [ - "--trust-remote-code", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - "moe_runner_auto_not_refactored": { # 'auto' where the potential backend has not been refactored - "model": DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, - "other_args": [ - "--trust-remote-code", - "--quantization", - "w8a8_int8", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", - "--attention-backend", - "torch_native", - "--sampling-backend", - "pytorch", - ], - }, - } - - def _run_config(self, config: dict) -> None: - model = config["model"] - other_args = config.get("other_args", []) - eval_kwargs = self.DEFAULT_EVAL_KWARGS - - process = popen_launch_server( - model, - self.BASE_URL, - timeout=self.TIMEOUT, - other_args=other_args, - ) - try: - args = SimpleNamespace( - 
base_url=self.BASE_URL, - model=model, - **eval_kwargs, - ) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreaterEqual(metrics["score"], 0.48) - finally: - kill_process_tree(process.pid) - - -for _name, _cfg in TestMoERunner.CONFIGS.items(): - setattr( - TestMoERunner, - f"test_{_name}", - (lambda self, cfg=_cfg: self._run_config(cfg)), - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index ecae12ebc5c3..8f32969b3785 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -26,7 +26,7 @@ class TestMoERunner(CustomTestCase): } CONFIGS = { - "moe_runner_triton_unquant_standard": { + "moe_runner_auto": { "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", @@ -44,12 +44,12 @@ class TestMoERunner(CustomTestCase): "pytorch", ], }, - "moe_runner_triton_kernel_unquant_standard": { + "moe_runner_triton": { "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", "--moe-runner-backend", - "triton_kernel", + "triton", "--tp", "2", "--max-total-tokens", @@ -62,14 +62,12 @@ class TestMoERunner(CustomTestCase): "pytorch", ], }, - "moe_runner_deep_gemm_awq_quantization": { - "model": DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, + "moe_runner_triton_kernel": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", "--moe-runner-backend", - "deep_gemm", - "--quantization", - "awq", + "triton_kernel", "--tp", "2", "--max-total-tokens", @@ -82,48 +80,57 @@ class TestMoERunner(CustomTestCase): "pytorch", ], }, - "moe_runner_flashinfer_trtllm_mxfp4_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, + "moe_runner_flashinfer_cutlass": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", "--moe-runner-backend", - "flashinfer_trtllm", - "--quantization", - 
"mxfp4", + "flashinfer_cutlass", "--tp", "2", "--max-total-tokens", "2048", "--mem-fraction-static", "0.95", + "--attention-backend", + "torch_native", + "--sampling-backend", + "pytorch", ], }, - # Speculative decoding (NGRAM) with FlashInfer MXFP4 backend (differs from main path) - "moe_runner_flashinfer_mxfp4_quantization_spec_ngram": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, + "moe_runner_deep_gemm": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", "--moe-runner-backend", - "flashinfer_mxfp4", - "--quantization", - "mxfp4", + "deep_gemm", "--tp", "2", "--max-total-tokens", "2048", "--mem-fraction-static", "0.95", - "--speculative-algorithm", - "NGRAM", - "--speculative-num-draft-tokens", - "8", "--attention-backend", "torch_native", "--sampling-backend", "pytorch", ], }, - "moe_runner_flashinfer_mxfp4_quantization": { + "moe_runner_flashinfer_trtllm": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "other_args": [ + "--trust-remote-code", + "--moe-runner-backend", + "flashinfer_trtllm", + "--tp", + "2", + "--max-total-tokens", + "2048", + "--mem-fraction-static", + "0.95", + ], + }, + "moe_runner_flashinfer_mxfp4": { "model": DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, "other_args": [ "--trust-remote-code", @@ -143,14 +150,12 @@ class TestMoERunner(CustomTestCase): "pytorch", ], }, - "moe_runner_flashinfer_cutedsl_fp8_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + "moe_runner_flashinfer_cutedsl": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ "--trust-remote-code", "--moe-runner-backend", "flashinfer_cutedsl", - "--quantization", - "fp8", "--tp", "2", "--max-total-tokens", @@ -163,14 +168,12 @@ class TestMoERunner(CustomTestCase): "pytorch", ], }, - "moe_runner_cutlass_w8a8_quantization": { - "model": DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, + "moe_runner_cutlass": { + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, "other_args": [ 
"--trust-remote-code", "--moe-runner-backend", "cutlass", - "--quantization", - "w8a8_int8", "--tp", "2", "--max-total-tokens", From 53e9b18fc7cf42b38013204bc2374c77c28557b7 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 28 Oct 2025 14:26:38 -0400 Subject: [PATCH 015/150] Add helpful comments --- test/srt/layers/moe/test_moe_runners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index 8f32969b3785..1ee8614c166c 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -81,7 +81,7 @@ class TestMoERunner(CustomTestCase): ], }, "moe_runner_flashinfer_cutlass": { - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, # requires model with modelopt_fp4 quantization "other_args": [ "--trust-remote-code", "--moe-runner-backend", From 3e36425665a3343210a51dacad9c45e522e3ffdf Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 28 Oct 2025 15:00:56 -0400 Subject: [PATCH 016/150] Correct comment --- python/sglang/srt/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e673febea827..a76cd301da88 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1305,7 +1305,7 @@ def _handle_moe_kernel_config(self): if self.moe_runner_backend == "flashinfer_cutlass": assert ( self.quantization == "modelopt_fp4" - ), "modelopt_fp4 quantization is required for Flashinfer MOE" + ), "modelopt_fp4 quantization is required for Flashinfer Cutlass MOE" assert self.ep_size in [ 1, self.tp_size, From 80812911298a41b6eb39242bd019c52ef08ce211 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 28 Oct 2025 15:04:06 -0400 Subject: [PATCH 017/150] Adjust default model for each test case --- test/srt/layers/moe/test_moe_runners.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index 1ee8614c166c..456c88098ea2 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -117,7 +117,7 @@ class TestMoERunner(CustomTestCase): ], }, "moe_runner_flashinfer_trtllm": { - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "model": DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, # modelopt_fp4 or fp8 quantization is required for Flashinfer trtllm MOE "other_args": [ "--trust-remote-code", "--moe-runner-backend", From e0193ecea89b28bdd1235baced41c1ae68e76f2a Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 3 Nov 2025 10:48:12 -0500 Subject: [PATCH 018/150] Add default moe NVFP4 model name for test --- python/sglang/test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 2b65561c2a44..575898092940 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -58,6 +58,7 @@ # NVFP4 models DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-R1-0528-FP4" +DEFAULT_MODEL_NAME_FOR_TEST_MOE_NVFP4 = "nvidia/Qwen3-30B-A3B-FP4" # FP8 models DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8" @@ -950,7 +951,6 @@ def run_score_benchmark( ) async def _run_benchmark(): - # Load tokenizer for generating test data from sglang.srt.utils.hf_transformers_utils import get_tokenizer From aa9382ae0f4e39050d28bf1bca484f55c64f01a2 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 3 Nov 2025 10:50:05 -0500 Subject: [PATCH 019/150] Wire default NVFP4 moe model into moe integration tests --- test/srt/layers/moe/test_moe_runners.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index 456c88098ea2..b69116bc0503 100644 --- 
a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -6,6 +6,7 @@ from sglang.test.test_utils import ( DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE, + DEFAULT_MODEL_NAME_FOR_TEST_MOE_NVFP4, DEFAULT_MODEL_NAME_FOR_TEST_MXFP4_WITH_MOE, DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE, DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, @@ -81,7 +82,7 @@ class TestMoERunner(CustomTestCase): ], }, "moe_runner_flashinfer_cutlass": { - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, # requires model with modelopt_fp4 quantization + "model": DEFAULT_MODEL_NAME_FOR_TEST_MOE_NVFP4, # requires model with modelopt_fp4 quantization "other_args": [ "--trust-remote-code", "--moe-runner-backend", From 3a92cebc519357b8f42d54f21d1eb50c26decef3 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 3 Nov 2025 15:01:57 -0500 Subject: [PATCH 020/150] Wire default NVFP4 moe model into moe integration tests configs --- test/srt/layers/moe/test_moe_runners.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index b69116bc0503..63e0e1cd6b42 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -152,7 +152,7 @@ class TestMoERunner(CustomTestCase): ], }, "moe_runner_flashinfer_cutedsl": { - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "model": DEFAULT_MODEL_NAME_FOR_TEST_MOE_NVFP4, "other_args": [ "--trust-remote-code", "--moe-runner-backend", @@ -170,7 +170,7 @@ class TestMoERunner(CustomTestCase): ], }, "moe_runner_cutlass": { - "model": DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, + "model": DEFAULT_MODEL_NAME_FOR_TEST_MOE_NVFP4, "other_args": [ "--trust-remote-code", "--moe-runner-backend", From f8e09fa4329659ca7daced355ffc7d72005e37ca Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 3 Nov 2025 16:03:58 -0500 Subject: [PATCH 021/150] Remove unnecessary args --- 
test/srt/layers/moe/test_moe_runners.py | 54 ------------------------- 1 file changed, 54 deletions(-) diff --git a/test/srt/layers/moe/test_moe_runners.py b/test/srt/layers/moe/test_moe_runners.py index 63e0e1cd6b42..ac5b3829e391 100644 --- a/test/srt/layers/moe/test_moe_runners.py +++ b/test/srt/layers/moe/test_moe_runners.py @@ -33,12 +33,6 @@ class TestMoERunner(CustomTestCase): "--trust-remote-code", "--moe-runner-backend", "triton", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", "--attention-backend", "torch_native", "--sampling-backend", @@ -51,12 +45,6 @@ class TestMoERunner(CustomTestCase): "--trust-remote-code", "--moe-runner-backend", "triton", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", "--attention-backend", "torch_native", "--sampling-backend", @@ -69,12 +57,6 @@ class TestMoERunner(CustomTestCase): "--trust-remote-code", "--moe-runner-backend", "triton_kernel", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", "--attention-backend", "torch_native", "--sampling-backend", @@ -87,12 +69,6 @@ class TestMoERunner(CustomTestCase): "--trust-remote-code", "--moe-runner-backend", "flashinfer_cutlass", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", "--attention-backend", "torch_native", "--sampling-backend", @@ -105,12 +81,6 @@ class TestMoERunner(CustomTestCase): "--trust-remote-code", "--moe-runner-backend", "deep_gemm", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", "--attention-backend", "torch_native", "--sampling-backend", @@ -123,12 +93,6 @@ class TestMoERunner(CustomTestCase): "--trust-remote-code", "--moe-runner-backend", "flashinfer_trtllm", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", ], }, "moe_runner_flashinfer_mxfp4": { @@ -139,12 +103,6 @@ class TestMoERunner(CustomTestCase): "flashinfer_mxfp4", 
"--quantization", "mxfp4", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", "--attention-backend", "torch_native", "--sampling-backend", @@ -157,12 +115,6 @@ class TestMoERunner(CustomTestCase): "--trust-remote-code", "--moe-runner-backend", "flashinfer_cutedsl", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", "--attention-backend", "torch_native", "--sampling-backend", @@ -175,12 +127,6 @@ class TestMoERunner(CustomTestCase): "--trust-remote-code", "--moe-runner-backend", "cutlass", - "--tp", - "2", - "--max-total-tokens", - "2048", - "--mem-fraction-static", - "0.95", "--attention-backend", "torch_native", "--sampling-backend", From aad84ef8ec74b3bb7f06deb01c33b1677087e463 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 5 Nov 2025 09:59:15 -0500 Subject: [PATCH 022/150] Add to not_in_ci list --- test/srt/run_suite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 1e9df7340997..1532f6d19f8d 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -234,6 +234,7 @@ class TestFile: TestFile("hicache/test_hicache_storage_benchmark.py"), TestFile("hicache/test_hicache_storage_e2e.py"), TestFile("layers/attention/nsa/test_act_quant_triton.py"), + TestFile("layers/moe/test_moe_runners.py"), TestFile("lora/test_chunked_sgmv_backend.py"), TestFile("lora/test_lora_llama4.py"), TestFile("models/lora/test_lora.py"), From ccd9ecad14bce4f3e3d1e40cbd48ebf276502e3d Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 5 Nov 2025 21:09:25 -0500 Subject: [PATCH 023/150] fix lint issues --- python/sglang/srt/layers/moe/lora_moe.py | 128 +++++++++ python/sglang/srt/lora/lora_manager.py | 18 +- python/sglang/srt/lora/mem_pool.py | 146 ++++++++-- python/sglang/srt/lora/moe_dispatch.py | 61 +++++ python/sglang/srt/lora/triton_ops/__init__.py | 2 + .../lora/triton_ops/per_expert_lora_moe.py | 253 ++++++++++++++++++ 6 files changed, 580 
insertions(+), 28 deletions(-) create mode 100644 python/sglang/srt/layers/moe/lora_moe.py create mode 100644 python/sglang/srt/lora/moe_dispatch.py create mode 100644 python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py diff --git a/python/sglang/srt/layers/moe/lora_moe.py b/python/sglang/srt/layers/moe/lora_moe.py new file mode 100644 index 000000000000..b37f3af52576 --- /dev/null +++ b/python/sglang/srt/layers/moe/lora_moe.py @@ -0,0 +1,128 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""FusedMoE layer with LoRA support.""" + +import torch +from torch import nn + +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.topk import TopKOutput +from sglang.srt.lora.backend.base_backend import BaseLoRABackend + + +class FusedMoEWithLoRA(nn.Module): + """ + Wrapper around FusedMoE that adds parallel LoRA computation. + + Design: Base MoE and LoRA Delta run independently and merge at the end. + This preserves SGLang's existing 3-stage MoE architecture unchanged. 
+ """ + + def __init__( + self, + base_moe: FusedMoE, + lora_backend: BaseLoRABackend, + ): + super().__init__() + self.base_moe = base_moe + self.lora_backend = lora_backend + self.lora_enabled = False + + # LoRA tensors will be set by LoRAManager + self.lora_a_weights = None + self.lora_b_weights = None + + def set_lora_info( + self, + lora_a_weights: torch.Tensor, + lora_b_weights: torch.Tensor, + ): + """Set LoRA weight tensors from memory pool.""" + self.lora_enabled = True + self.lora_a_weights = lora_a_weights + self.lora_b_weights = lora_b_weights + + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): + """ + Forward pass with parallel LoRA computation. + + Flow: + 1. Base MoE forward + 2. Parallel LoRA delta computation (if enabled, added in-place) + 3. Return modified base_output + """ + # Run base MoE + base_output = self.base_moe.forward(hidden_states, topk_output, **kwargs) + + # If LoRA is enabled, compute delta and add in-place for memory efficiency + if self.lora_enabled and self.lora_a_weights is not None: + self._compute_lora_delta(hidden_states, topk_output, base_output) + + return base_output + + def _compute_lora_delta( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + base_output: torch.Tensor, + ) -> None: + """ + Compute LoRA delta using per-expert LoRA weights and add to base_output in-place. + + Dispatch tokens to experts with LoRA-aware routing and compute per-expert deltas. 
+ """ + from sglang.srt.lora.moe_dispatch import per_lora_moe_dispatch + from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( + per_expert_lora_forward, + ) + + # Get dispatch info from TopKOutput + topk_ids = topk_output.topk_ids # [num_tokens, top_k] + topk_weights = topk_output.topk_weights # [num_tokens, top_k] + + # Get LoRA batch info from backend + batch_info = self.lora_backend.batch_info + weight_indices = batch_info.weight_indices # [num_tokens] + lora_ranks = batch_info.lora_ranks # [num_loras] + scalings = batch_info.scalings # [num_loras] + + num_experts = self.base_moe.num_experts + num_loras = self.lora_a_weights.shape[0] + + # Dispatch tokens to (lora, expert) pairs + token_ids, expert_ids, _ = per_lora_moe_dispatch( + topk_ids=topk_ids, + topk_weights=topk_weights, + weight_indices=weight_indices, + num_experts=num_experts, + num_loras=num_loras, + ) + + # Get LoRA IDs for dispatched tokens + lora_ids = weight_indices[token_ids] + + # Compute per-expert LoRA forward (adds to base_output in-place) + per_expert_lora_forward( + hidden_states=hidden_states, + lora_a_weights=self.lora_a_weights, + lora_b_weights=self.lora_b_weights, + token_ids=token_ids, + expert_ids=expert_ids, + lora_ids=lora_ids, + lora_ranks=lora_ranks, + lora_scalings=scalings, + num_experts=num_experts, + base_output=base_output, + ) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 19ff874dc1da..0a39b26b907b 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -441,6 +441,14 @@ def set_lora_module(self, module_name, module): replace_submodule(self.base_model, module_name, lora_module) return lora_module + def set_moe_lora_module(self, module_name, module): + """Wrap MoE module with LoRA support.""" + from sglang.srt.layers.moe.lora_moe import FusedMoEWithLoRA + + lora_moe = FusedMoEWithLoRA(module, self.lora_backend) + replace_submodule(self.base_model, module_name, lora_moe) 
+ return lora_moe + def init_lora_modules(self): # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [ @@ -458,8 +466,16 @@ def init_lora_modules(self): ) and not self.base_model.should_apply_lora(module_name): continue + # Check if this is an MoE module first + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + + if isinstance(module, FusedMoE): + layer_id = get_layer_id(module_name) + self.lora_modules[layer_id][module_name] = self.set_moe_lora_module( + module_name, module + ) # The module should be converted if it is included in target_names - if module_name.split(".")[-1] in self.target_modules: + elif module_name.split(".")[-1] in self.target_modules: layer_id = get_layer_id(module_name) self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index f6375361700e..c1bc454ae336 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -70,10 +70,9 @@ def __init__( self.eviction_policy = get_eviction_policy(eviction_policy) # Both A_buffer and B_buffer maps lora weight names to its buffer space. 
- # A_buffer contains num_layer number of row-major tensors with shape - # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim) - # B_buffer contains num_layer number of column-major tensors with shape - # (stacked_num, max_loras_per_batch, output_dim, max_lora_dim) + # Standard LoRA (3D): [num_loras, rank, hidden_dim] + # MoE LoRA (4D): [num_loras, num_experts, rank, hidden_dim] + # The dimensionality is determined by the module type (MoE vs standard) self.A_buffer: Dict[str, List[torch.Tensor]] = {} self.B_buffer: Dict[str, List[torch.Tensor]] = {} @@ -108,6 +107,11 @@ def _can_support(config: LoRAConfig) -> bool: else: return all(_can_support(x) for x in config) + def is_moe_module(self, module_name: str) -> bool: + """Check if module is part of MoE experts.""" + moe_patterns = ["block_sparse_moe.experts", "experts.", "mlp.experts"] + return any(pattern in module_name for pattern in moe_patterns) + def get_lora_A_shape( self, module_name: str, @@ -116,7 +120,11 @@ def get_lora_A_shape( layer_idx: int, ) -> Tuple[int]: """ - Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. + Get shape for LoRA A weights. Automatically returns 3D or 4D based on module type. 
+ + Returns: + - Standard: [num_loras, rank, hidden_dim] + - MoE: [num_loras, num_experts, rank, hidden_dim] """ input_dim, _ = get_hidden_dim( module_name, self.base_hf_config, base_model, layer_idx @@ -124,11 +132,17 @@ def get_lora_A_shape( c = get_stacked_multiply(module_name) if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: input_dim = divide(input_dim, self.tp_size) - return ( - self.max_loras_per_batch, - max_lora_dim * c, - input_dim, - ) + + # Check if MoE module and return appropriate shape + if self.is_moe_module(module_name): + num_experts = getattr( + self.base_hf_config, + "num_local_experts", + getattr(self.base_hf_config, "num_experts", 0), + ) + return (self.max_loras_per_batch, num_experts, max_lora_dim, input_dim) + else: + return (self.max_loras_per_batch, max_lora_dim * c, input_dim) def get_lora_B_shape( self, @@ -138,18 +152,28 @@ def get_lora_B_shape( layer_idx: int, ) -> Tuple[int]: """ - Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. + Get shape for LoRA B weights. Automatically returns 3D or 4D based on module type. 
+ + Returns: + - Standard: [num_loras, output_dim, rank] + - MoE: [num_loras, num_experts, output_dim, rank] """ _, output_dim = get_hidden_dim( module_name, self.base_hf_config, base_model, layer_idx ) if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES: output_dim = divide(output_dim, self.tp_size) - return ( - self.max_loras_per_batch, - output_dim, - max_lora_dim, - ) + + # Check if MoE module and return appropriate shape + if self.is_moe_module(module_name): + num_experts = getattr( + self.base_hf_config, + "num_local_experts", + getattr(self.base_hf_config, "num_experts", 0), + ) + return (self.max_loras_per_batch, num_experts, output_dim, max_lora_dim) + else: + return (self.max_loras_per_batch, output_dim, max_lora_dim) def init_buffers(self, base_model: torch.nn.Module): device = next(base_model.parameters()).device @@ -174,6 +198,7 @@ def init_buffer( for idx in range(self.num_layer) ] + # Shape functions automatically handle both 3D (standard) and 4D (MoE) init_buffer( self.A_buffer, self.target_modules, @@ -279,18 +304,40 @@ def load_lora_weight_tensor( lora_rank = lora_adapter.config.r for layer_id in range(self.num_layer): layer_weights = lora_adapter.layers[layer_id].weights - temp_A_buffer: Dict[str, Optional[torch.Tensor]] = { + # - Standard: module_name -> torch.Tensor + # - MoE: module_name -> Dict[expert_id -> torch.Tensor] + temp_A_buffer: Dict[str, Union[torch.Tensor, Dict[int, torch.Tensor]]] = { target_module: None for target_module in self.A_buffer } - temp_B_buffer: Dict[str, Optional[torch.Tensor]] = { + temp_B_buffer: Dict[str, Union[torch.Tensor, Dict[int, torch.Tensor]]] = { target_module: None for target_module in self.B_buffer } + for name, weights in layer_weights.items(): target_module = get_target_module_name(name, self.target_modules) - if "lora_A" in name: - temp_A_buffer[target_module] = weights + + # Check if this is an MoE weight (has expert index in name) + import re + + expert_match = 
re.search(r"experts\.(\d+)\.", name) + + if expert_match and self.is_moe_module(target_module): + # MoE weight - multiple tensors per module (one per expert) + if temp_A_buffer[target_module] is None: + temp_A_buffer[target_module] = {} + temp_B_buffer[target_module] = {} + + expert_id = int(expert_match.group(1)) + if "lora_A" in name: + temp_A_buffer[target_module][expert_id] = weights + else: + temp_B_buffer[target_module][expert_id] = weights else: - temp_B_buffer[target_module] = weights + # Standard weight - single tensor per module + if "lora_A" in name: + temp_A_buffer[target_module] = weights + else: + temp_B_buffer[target_module] = weights if self.tp_size > 1: cur_layer_modules = lora_modules[layer_id] @@ -303,6 +350,25 @@ def load_lora_weight_tensor( # Skip weight slicing if the weight is not present in the adapter continue + # Handle MoE modules (they contain dicts of per-expert tensors) + if isinstance(temp_A_buffer[target_module], dict): + # Slice each expert's weights individually + for expert_id in temp_A_buffer[target_module].keys(): + temp_A_buffer[target_module][expert_id] = ( + module.slice_lora_a_weights( + temp_A_buffer[target_module][expert_id], + self.tp_rank, + ) + ) + temp_B_buffer[target_module][expert_id] = ( + module.slice_lora_b_weights( + temp_B_buffer[target_module][expert_id], + self.tp_rank, + ) + ) + continue + + # Handle standard modules temp_A_buffer[target_module] = module.slice_lora_a_weights( temp_A_buffer[target_module], self.tp_rank ) @@ -310,23 +376,49 @@ def load_lora_weight_tensor( temp_B_buffer[target_module], self.tp_rank ) + # Load weights into buffers (handles both 3D standard and 4D MoE) for name, weights in temp_A_buffer.items(): - c = get_stacked_multiply(name) + c = get_stacked_multiply(name) # TODO: delete this target_buffer = self.A_buffer[name][layer_id] - buffer_view = target_buffer[buffer_id, : lora_rank * c, :] - load_lora_weight_tensor(buffer_view, weights) + + if isinstance(weights, dict): + # MoE: 
multiple tensors per module (one per expert) + for expert_id, expert_weight in weights.items(): + # Buffer shape: [num_loras, num_experts, max_rank, hidden_dim] + buffer_view = target_buffer[buffer_id, expert_id, :lora_rank, :] + load_lora_weight_tensor(buffer_view, expert_weight) + else: + # Standard: single tensor per module + c = get_stacked_multiply(name) + buffer_view = target_buffer[buffer_id, : lora_rank * c, :] + load_lora_weight_tensor(buffer_view, weights) for name, weights in temp_B_buffer.items(): target_buffer = self.B_buffer[name][layer_id] - buffer_view = target_buffer[buffer_id, :, :lora_rank] - load_lora_weight_tensor(buffer_view, weights) + + if isinstance(weights, dict): + # MoE: multiple tensors per module (one per expert) + for expert_id, expert_weight in weights.items(): + # Buffer shape: [num_loras, num_experts, intermediate_dim, max_rank] + buffer_view = target_buffer[buffer_id, expert_id, :, :lora_rank] + load_lora_weight_tensor(buffer_view, expert_weight) + else: + # Standard: single tensor per module + buffer_view = target_buffer[buffer_id, :, :lora_rank] + load_lora_weight_tensor(buffer_view, weights) def get_tensor( self, target_module: str, layer_id: int, lora_type: LoRAType ) -> torch.Tensor: + """ + Get LoRA tensor buffer (automatically handles both 3D and 4D tensors). 
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""MoE-specific LoRA dispatch utilities."""

import torch


def per_lora_moe_dispatch(
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    weight_indices: torch.Tensor,
    num_experts: int,
    num_loras: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Dispatch tokens to experts with per-LoRA routing.

    Flattens the (token, top_k) routing produced by the MoE router and sorts
    the resulting assignments by the composite key
    ``lora_id * num_experts + expert_id`` so that all work for the same
    (LoRA adapter, expert) pair is contiguous in the output.

    Args:
        topk_ids: [num_tokens, top_k] - Expert IDs selected by the router.
        topk_weights: [num_tokens, top_k] - Router weights.
        weight_indices: [num_tokens] - LoRA adapter ID for each token.
        num_experts: Total number of experts.
        num_loras: Total number of LoRA adapters (kept for interface
            completeness; the sort key itself only needs ``num_experts``).

    Returns:
        sorted_token_ids: Token indices sorted by (lora_id, expert_id).
        sorted_expert_ids: Corresponding expert IDs.
        sorted_weights: Corresponding router weights.
    """
    num_tokens, top_k = topk_ids.shape
    device = topk_ids.device

    # Flatten the top-k dimension: each (token, k) routing decision becomes
    # one independent dispatch entry.
    flat_topk_ids = topk_ids.flatten()
    flat_topk_weights = topk_weights.flatten()
    flat_token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k)
    flat_lora_ids = weight_indices.repeat_interleave(top_k)

    # Composite key groups entries first by LoRA adapter, then by expert.
    composite_key = flat_lora_ids * num_experts + flat_topk_ids

    # stable=True keeps the original token order within each (lora, expert)
    # group; plain argsort leaves tie order unspecified, which made the
    # dispatch order (and any downstream atomics) nondeterministic.
    sorted_indices = torch.argsort(composite_key, stable=True)

    sorted_token_ids = flat_token_ids[sorted_indices]
    sorted_expert_ids = flat_topk_ids[sorted_indices]
    sorted_weights = flat_topk_weights[sorted_indices]

    return sorted_token_ids, sorted_expert_ids, sorted_weights
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Per-expert LoRA computation kernel for MoE layers."""

import torch
import triton
import triton.language as tl


@triton.jit
def _per_expert_lora_kernel(
    # Input/Output pointers
    hidden_states_ptr,  # [num_tokens, hidden_dim]
    lora_a_weights_ptr,  # [num_loras, num_experts, max_rank, hidden_dim]
    lora_b_weights_ptr,  # [num_loras, num_experts, intermediate_dim, max_rank]
    output_ptr,  # [num_tokens, intermediate_dim]
    # Dispatch info (length = num_dispatched)
    token_ids_ptr,
    expert_ids_ptr,
    lora_ids_ptr,
    # Dimensions
    hidden_dim: tl.constexpr,
    intermediate_dim: tl.constexpr,
    max_rank: tl.constexpr,
    num_experts: tl.constexpr,
    num_dispatched,
    # Strides for 4D LoRA A weights [num_loras, num_experts, max_rank, hidden_dim]
    lora_a_stride_lora: tl.constexpr,
    lora_a_stride_expert: tl.constexpr,
    lora_a_stride_rank: tl.constexpr,
    lora_a_stride_hidden: tl.constexpr,
    # Strides for 4D LoRA B weights [num_loras, num_experts, intermediate_dim, max_rank]
    lora_b_stride_lora: tl.constexpr,
    lora_b_stride_expert: tl.constexpr,
    lora_b_stride_intermediate: tl.constexpr,
    lora_b_stride_rank: tl.constexpr,
    # Per-adapter rank / scaling tables [num_loras]
    lora_ranks_ptr,
    lora_scalings_ptr,
    # Block sizes
    BLOCK_HIDDEN: tl.constexpr,
    BLOCK_INTERMEDIATE: tl.constexpr,
    BLOCK_RANK: tl.constexpr,
):
    """
    Compute the per-expert LoRA delta for one dispatched (token, expert) pair:

        delta = scaling * B[lora, expert] @ (A[lora, expert] @ hidden[token])

    Grid: (num_dispatched, cdiv(intermediate_dim, BLOCK_INTERMEDIATE))
    - axis 0: one program per dispatch entry (token/expert/lora triple)
    - axis 1: tile of the output (intermediate) dimension

    NOTE: the previous version indexed into a `tl` tensor with
    `intermediate_a[r] += ...`, which Triton rejects (tensors are immutable),
    and its grid axis 0 tiled tokens by BLOCK_HIDDEN — conflating the token
    count with the hidden-dim block size. This version vectorizes over the
    rank dimension with masked 2D loads instead of scalar loops.
    """
    dispatch_id = tl.program_id(0)
    out_tile_id = tl.program_id(1)

    if dispatch_id >= num_dispatched:
        return

    # Dispatch info for this entry.
    token_id = tl.load(token_ids_ptr + dispatch_id)
    expert_id = tl.load(expert_ids_ptr + dispatch_id)
    lora_id = tl.load(lora_ids_ptr + dispatch_id)

    rank = tl.load(lora_ranks_ptr + lora_id)
    if rank == 0:
        # Adapter slot is unused; nothing to add.
        return
    scaling = tl.load(lora_scalings_ptr + lora_id)

    rank_offs = tl.arange(0, BLOCK_RANK)
    rank_mask = rank_offs < rank

    # ---- Shrink: acc[r] = sum_k A[r, k] * h[k], tiled over the hidden dim ----
    acc = tl.zeros([BLOCK_RANK], dtype=tl.float32)
    hidden_base = hidden_states_ptr + token_id * hidden_dim
    a_base = (
        lora_a_weights_ptr
        + lora_id * lora_a_stride_lora
        + expert_id * lora_a_stride_expert
    )
    for k_start in range(0, hidden_dim, BLOCK_HIDDEN):
        k_offs = k_start + tl.arange(0, BLOCK_HIDDEN)
        k_mask = k_offs < hidden_dim
        h = tl.load(hidden_base + k_offs, mask=k_mask, other=0.0).to(tl.float32)
        a_ptrs = (
            a_base
            + rank_offs[:, None] * lora_a_stride_rank
            + k_offs[None, :] * lora_a_stride_hidden
        )
        a = tl.load(
            a_ptrs, mask=rank_mask[:, None] & k_mask[None, :], other=0.0
        ).to(tl.float32)
        acc += tl.sum(a * h[None, :], axis=1)

    # ---- Expand: delta[o] = sum_r B[o, r] * acc[r] for this output tile ----
    out_offs = out_tile_id * BLOCK_INTERMEDIATE + tl.arange(0, BLOCK_INTERMEDIATE)
    out_mask = out_offs < intermediate_dim
    b_base = (
        lora_b_weights_ptr
        + lora_id * lora_b_stride_lora
        + expert_id * lora_b_stride_expert
    )
    b_ptrs = (
        b_base
        + out_offs[:, None] * lora_b_stride_intermediate
        + rank_offs[None, :] * lora_b_stride_rank
    )
    b = tl.load(
        b_ptrs, mask=out_mask[:, None] & rank_mask[None, :], other=0.0
    ).to(tl.float32)
    delta = tl.sum(b * acc[None, :], axis=1) * scaling

    # Atomic add: multiple dispatch entries (top-k routing) may target the
    # same token row of the output.
    out_ptrs = output_ptr + token_id * intermediate_dim + out_offs
    tl.atomic_add(out_ptrs, delta, mask=out_mask)


def per_expert_lora_forward(
    hidden_states: torch.Tensor,
    lora_a_weights: torch.Tensor,
    lora_b_weights: torch.Tensor,
    token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    lora_ids: torch.Tensor,
    lora_ranks: torch.Tensor,
    lora_scalings: torch.Tensor,
    num_experts: int,
    base_output: torch.Tensor = None,
) -> torch.Tensor:
    """
    Forward pass for per-expert LoRA computation.

    Args:
        hidden_states: [num_tokens, hidden_dim]
        lora_a_weights: [num_loras, num_experts, max_rank, hidden_dim]
        lora_b_weights: [num_loras, num_experts, intermediate_dim, max_rank]
        token_ids: [num_dispatched] - Original token indices
        expert_ids: [num_dispatched] - Expert ID for each dispatched token
        lora_ids: [num_dispatched] - LoRA ID for each dispatched token
        lora_ranks: [num_loras] - Rank for each LoRA
        lora_scalings: [num_loras] - Scaling factor for each LoRA
        num_experts: Total number of experts
        base_output: [num_tokens, intermediate_dim] - Base MoE output
            (modified in-place when provided)

    Returns:
        output: [num_tokens, intermediate_dim] - Base output + LoRA delta
            (in-place when ``base_output`` is given)
    """
    num_tokens, hidden_dim = hidden_states.shape
    num_loras, _, intermediate_dim, max_rank = lora_b_weights.shape
    num_dispatched = token_ids.shape[0]

    # Initialize or reuse the output tensor for in-place accumulation.
    if base_output is None:
        output = torch.zeros(
            num_tokens,
            intermediate_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
    else:
        output = base_output

    # Nothing routed (e.g. empty batch): skip the kernel launch entirely.
    if num_dispatched == 0:
        return output

    # Block sizes. BLOCK_RANK must cover the full rank range — the previous
    # hard-coded 64 silently truncated adapters with rank > 64.
    BLOCK_HIDDEN = 128
    BLOCK_INTERMEDIATE = 128
    BLOCK_RANK = triton.next_power_of_2(max(int(max_rank), 1))

    # One program per dispatch entry; axis 1 tiles the output dimension.
    grid = (
        num_dispatched,
        triton.cdiv(intermediate_dim, BLOCK_INTERMEDIATE),
    )

    _per_expert_lora_kernel[grid](
        hidden_states,
        lora_a_weights,
        lora_b_weights,
        output,
        token_ids,
        expert_ids,
        lora_ids,
        hidden_dim,
        intermediate_dim,
        max_rank,
        num_experts,
        num_dispatched,
        # LoRA A strides: [num_loras, num_experts, max_rank, hidden_dim]
        lora_a_weights.stride(0),
        lora_a_weights.stride(1),
        lora_a_weights.stride(2),
        lora_a_weights.stride(3),
        # LoRA B strides: [num_loras, num_experts, intermediate_dim, max_rank]
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        lora_b_weights.stride(3),
        lora_ranks,
        lora_scalings,
        BLOCK_HIDDEN,
        BLOCK_INTERMEDIATE,
        BLOCK_RANK,
    )

    return output
def init_lora_modules(self): # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [ @@ -466,16 +459,8 @@ def init_lora_modules(self): ) and not self.base_model.should_apply_lora(module_name): continue - # Check if this is an MoE module first - from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE - - if isinstance(module, FusedMoE): - layer_id = get_layer_id(module_name) - self.lora_modules[layer_id][module_name] = self.set_moe_lora_module( - module_name, module - ) - # The module should be converted if it is included in target_names - elif module_name.split(".")[-1] in self.target_modules: + # Check if module should be wrapped with LoRA + if module_name.split(".")[-1] in self.target_modules: layer_id = get_layer_id(module_name) self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module From b218f23be5141b93a2af154e0dede14500fe56da Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 19 Nov 2025 03:27:41 +0000 Subject: [PATCH 025/150] fix --- python/sglang/srt/layers/moe/lora_moe.py | 18 +- python/sglang/srt/lora/lora_manager.py | 67 ++++ python/sglang/srt/lora/mem_pool.py | 94 ++++- python/sglang/srt/lora/moe_dispatch.py | 19 +- .../lora/triton_ops/per_expert_lora_moe.py | 360 +++++++++++------- python/sglang/srt/lora/utils.py | 5 +- .../srt/model_executor/forward_batch_info.py | 2 + .../sglang/srt/model_executor/model_runner.py | 1 - 8 files changed, 379 insertions(+), 187 deletions(-) diff --git a/python/sglang/srt/layers/moe/lora_moe.py b/python/sglang/srt/layers/moe/lora_moe.py index b37f3af52576..f4bf7886c7b9 100644 --- a/python/sglang/srt/layers/moe/lora_moe.py +++ b/python/sglang/srt/layers/moe/lora_moe.py @@ -81,9 +81,9 @@ def _compute_lora_delta( """ Compute LoRA delta using per-expert LoRA weights and add to base_output in-place. - Dispatch tokens to experts with LoRA-aware routing and compute per-expert deltas. 
+ Dispatch tokens to experts and compute per-expert deltas. """ - from sglang.srt.lora.moe_dispatch import per_lora_moe_dispatch + from sglang.srt.lora.moe_dispatch import moe_dispatch from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( per_expert_lora_forward, ) @@ -94,24 +94,28 @@ def _compute_lora_delta( # Get LoRA batch info from backend batch_info = self.lora_backend.batch_info - weight_indices = batch_info.weight_indices # [num_tokens] lora_ranks = batch_info.lora_ranks # [num_loras] scalings = batch_info.scalings # [num_loras] + # Use precomputed per-token LoRA indices from forward batch + lora_indices = self.lora_backend.forward_batch.token_lora_indices + num_experts = self.base_moe.num_experts num_loras = self.lora_a_weights.shape[0] - # Dispatch tokens to (lora, expert) pairs - token_ids, expert_ids, _ = per_lora_moe_dispatch( + # Dispatch tokens to experts + token_ids, expert_ids, _ = moe_dispatch( topk_ids=topk_ids, topk_weights=topk_weights, - weight_indices=weight_indices, + lora_indices=lora_indices, num_experts=num_experts, num_loras=num_loras, ) # Get LoRA IDs for dispatched tokens - lora_ids = weight_indices[token_ids] + lora_ids = lora_indices[token_ids] + + # Compute per-expert LoRA forward (adds to base_output in-place) per_expert_lora_forward( diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 558204aa55ac..9704275d91e8 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -40,6 +40,8 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_npu, replace_submodule from sglang.srt.utils.hf_transformers_utils import AutoConfig +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.lora_moe import FusedMoEWithLoRA if is_npu(): from torch_npu.contrib import transfer_to_npu # noqa: F401 @@ -300,25 +302,83 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): 
batch_info=self.cuda_graph_batch_info if use_cuda_graph else None, ) + # Populate per-token LoRA indices from segment information + batch_info = self.lora_backend.batch_info + num_tokens = forward_batch.batch_size + if batch_info.permutation is None: + # No reordering (e.g., triton backend): segments are in original order + token_lora_indices = torch.empty(num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device) + seg_indptr = batch_info.seg_indptr # [num_segments + 1] + for seg_idx in range(batch_info.num_segments): + start_token = seg_indptr[seg_idx] + end_token = seg_indptr[seg_idx + 1] + lora_adapter = batch_info.weight_indices[seg_idx] + token_lora_indices[start_token:end_token] = lora_adapter + else: + # Tokens are reordered (chunked backend): need to convert back to original order + token_lora_indices_reordered = torch.empty(num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device) + seg_indptr = batch_info.seg_indptr # [num_segments + 1] + for seg_idx in range(batch_info.num_segments): + start_token = seg_indptr[seg_idx] + end_token = seg_indptr[seg_idx + 1] + lora_adapter = batch_info.weight_indices[seg_idx] + token_lora_indices_reordered[start_token:end_token] = lora_adapter + + # Convert back to original token order using inverse permutation + inverse_permutation = torch.empty_like(batch_info.permutation) + inverse_permutation[batch_info.permutation] = torch.arange(num_tokens, dtype=batch_info.permutation.dtype, device=batch_info.permutation.device) + token_lora_indices = token_lora_indices_reordered[inverse_permutation] + + forward_batch.token_lora_indices = token_lora_indices + + # Store forward_batch reference in backend for MoE layer access + self.lora_backend.forward_batch = forward_batch + def update_lora_info(self): """ Update all LoRA modules to associate them with the latest memory buffer. 
""" for layer_id, layer_modules in enumerate(self.lora_modules): for module_name, module in layer_modules.items(): + # Hack for FusedMoE layer + if isinstance(module, FusedMoEWithLoRA) and all(x in self.target_modules for x in ['gate_up_proj', 'down_proj']): + module.set_lora_info( + self.memory_pool.get_tensor( + target_module='gate_up_proj', + layer_id=layer_id, + lora_type=LoRAType.LORA_A, + context='moe', + ), + self.memory_pool.get_tensor( + target_module='down_proj', + layer_id=layer_id, + lora_type=LoRAType.LORA_B, + context='moe', + ), + ) + continue + target_module = get_target_module_name( module_name, self.memory_pool.target_modules ) + + # Determine context based on module name + context = None + if isinstance(module, FusedMoEWithLoRA): + context = "moe" + module.set_lora_info( self.memory_pool.get_tensor( target_module=target_module, layer_id=layer_id, lora_type=LoRAType.LORA_A, + context=context, ), self.memory_pool.get_tensor( target_module=target_module, layer_id=layer_id, lora_type=LoRAType.LORA_B, + context=context, ), ) @@ -473,3 +533,10 @@ def init_lora_modules(self): self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module ) + continue + + # Temporarily workaround for FusedMoE layer + if isinstance(module, FusedMoE) and all(x in self.target_modules for x in ['gate_up_proj', 'down_proj']): + self.lora_modules[layer_id][module_name] = self.set_lora_module( + module_name, module + ) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index c1bc454ae336..d68cf2b913f0 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -109,8 +109,17 @@ def _can_support(config: LoRAConfig) -> bool: def is_moe_module(self, module_name: str) -> bool: """Check if module is part of MoE experts.""" - moe_patterns = ["block_sparse_moe.experts", "experts.", "mlp.experts"] - return any(pattern in module_name for pattern in moe_patterns) + return "moe" in module_name + + def 
_get_standard_shape(self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int, layer_idx: int) -> Tuple[int]: + """Get 3D shape for standard (non-MoE) modules.""" + input_dim, _ = get_hidden_dim( + module_name, self.base_hf_config, base_model, layer_idx + ) + c = get_stacked_multiply(module_name) + if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: + input_dim = divide(input_dim, self.tp_size) + return (self.max_loras_per_batch, max_lora_dim * c, input_dim) def get_lora_A_shape( self, @@ -133,7 +142,7 @@ def get_lora_A_shape( if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: input_dim = divide(input_dim, self.tp_size) - # Check if MoE module and return appropriate shape + # Check if MoE module and return appropriate shape (the assumption is that down_proj and gate_up_proj are only used in MoE modules) if self.is_moe_module(module_name): num_experts = getattr( self.base_hf_config, @@ -183,20 +192,52 @@ def init_buffer( target_modules: Set[str], get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]], ): + # Check if model has both shared experts and MoE experts + has_shared_experts = hasattr(base_model.config, 'shared_expert_intermediate_size') and \ + base_model.config.shared_expert_intermediate_size > 0 + has_moe = getattr(base_model.config, "num_experts", 1) > 1 + for module_name in target_modules: - buffer[module_name] = [ - torch.empty( - get_lora_shape_fn( - module_name, - base_model, - self.max_lora_rank, - idx, - ), - dtype=self.dtype, - device=device, - ) - for idx in range(self.num_layer) - ] + # Special handling for ambiguous target modules that can be in different contexts + ambiguous_modules = {"gate_up_proj", "down_proj"} + if module_name in ambiguous_modules and has_shared_experts and has_moe: + # Allocate separate buffers for shared and MoE contexts + # Shared expert version (3D) + shared_key = module_name + buffer[shared_key] = [ + torch.empty( + 
get_lora_shape_fn(module_name, base_model, self.max_lora_rank, idx), + dtype=self.dtype, + device=device, + ) + for idx in range(self.num_layer) + ] + + # MoE expert version (4D) + moe_key = f"{module_name}_moe" + buffer[moe_key] = [ + torch.empty( + get_lora_shape_fn(moe_key, base_model, self.max_lora_rank, idx), + dtype=self.dtype, + device=device, + ) + for idx in range(self.num_layer) + ] + else: + # Standard allocation for unambiguous modules + buffer[module_name] = [ + torch.empty( + get_lora_shape_fn( + module_name, + base_model, + self.max_lora_rank, + idx, + ), + dtype=self.dtype, + device=device, + ) + for idx in range(self.num_layer) + ] # Shape functions automatically handle both 3D (standard) and 4D (MoE) init_buffer( @@ -408,18 +449,31 @@ def load_lora_weight_tensor( load_lora_weight_tensor(buffer_view, weights) def get_tensor( - self, target_module: str, layer_id: int, lora_type: LoRAType + self, target_module: str, layer_id: int, lora_type: LoRAType, context: str = None ) -> torch.Tensor: """ Get LoRA tensor buffer (automatically handles both 3D and 4D tensors). 
+ Args: + target_module: Target module name (e.g., 'gate_up_proj') + layer_id: Layer index + lora_type: LoRAType.LORA_A or LoRAType.LORA_B + context: Optional context hint ('moe' or None for auto-detect) + Returns: - 3D tensor [num_loras, rank, hidden] for standard modules - 4D tensor [num_loras, num_experts, rank, hidden] for MoE modules """ - if lora_type == LoRAType.LORA_A: - return self.A_buffer[target_module][layer_id] - return self.B_buffer[target_module][layer_id] + buffer_dict = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer + + # Handle context-specific buffer selection for ambiguous modules + ambiguous_modules = {"gate_up_proj", "down_proj"} + if target_module in ambiguous_modules: + if context == "moe" and f"{target_module}_moe" in buffer_dict: + return buffer_dict[f"{target_module}_moe"][layer_id] + + # Fall back to original key for non-ambiguous modules + return buffer_dict[target_module][layer_id] def get_buffer_id(self, lora_uid: str): return self.uid_to_buffer_id[lora_uid] diff --git a/python/sglang/srt/lora/moe_dispatch.py b/python/sglang/srt/lora/moe_dispatch.py index 3fc4c7a709d0..6a924b20b2f6 100644 --- a/python/sglang/srt/lora/moe_dispatch.py +++ b/python/sglang/srt/lora/moe_dispatch.py @@ -12,30 +12,30 @@ # limitations under the License. # ============================================================================== -"""MoE-specific LoRA dispatch utilities.""" +"""MoE dispatch utilities.""" import torch -def per_lora_moe_dispatch( +def moe_dispatch( topk_ids: torch.Tensor, topk_weights: torch.Tensor, - weight_indices: torch.Tensor, + lora_indices: torch.Tensor, num_experts: int, num_loras: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Dispatch tokens to experts with per-LoRA routing. + Dispatch tokens to experts for MoE computation. 
Args: topk_ids: [num_tokens, top_k] - Expert IDs selected by router topk_weights: [num_tokens, top_k] - Router weights - weight_indices: [num_tokens] - LoRA adapter ID for each token + lora_indices: [num_tokens] - LoRA adapter ID for each token num_experts: Total number of experts num_loras: Total number of LoRA adapters Returns: - sorted_token_ids: Token indices sorted by (lora_id, expert_id) + sorted_token_ids: Token indices sorted by expert_id sorted_expert_ids: Corresponding expert IDs sorted_weights: Corresponding router weights """ @@ -46,12 +46,11 @@ def per_lora_moe_dispatch( flat_topk_ids = topk_ids.flatten() flat_topk_weights = topk_weights.flatten() flat_token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k) - flat_lora_ids = weight_indices.repeat_interleave(top_k) - # Create composite key for sorting: lora_id * num_experts + expert_id - composite_key = flat_lora_ids * num_experts + flat_topk_ids + # Sort by expert_id only (each expert uses same LoRA adapter logic) + composite_key = flat_topk_ids - # Sort by composite key to group by (lora_id, expert_id) + # Sort by expert_id to group tokens by expert sorted_indices = torch.argsort(composite_key) sorted_token_ids = flat_token_ids[sorted_indices] diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index 9132d0200fcf..92553f9ed07a 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -22,146 +22,190 @@ @triton.jit def _per_expert_lora_kernel( # Input/Output pointers - hidden_states_ptr, - lora_a_weights_ptr, - lora_b_weights_ptr, - output_ptr, - # Dispatch info - token_ids_ptr, - expert_ids_ptr, - lora_ids_ptr, + hidden_states_ptr, # [num_total_tokens, hidden_dim] + lora_a_weights_ptr, # [num_loras, num_experts, max_rank, hidden_dim] + lora_b_weights_ptr, # [num_loras, num_experts, intermediate_dim, max_rank] + output_ptr, # 
[num_total_tokens, intermediate_dim] + + # Dispatch info (length = num_dispatched) + token_ids_ptr, # [num_dispatched] -> index into hidden/output + expert_ids_ptr, # [num_dispatched] + lora_ids_ptr, # [num_dispatched] + # Dimensions hidden_dim: tl.constexpr, intermediate_dim: tl.constexpr, max_rank: tl.constexpr, num_experts: tl.constexpr, - num_tokens: tl.constexpr, - # Strides for 4D LoRA weights [num_loras, num_experts, *, *] + num_dispatched, + + # Strides for 4D LoRA A weights [num_loras, num_experts, max_rank, hidden_dim] lora_a_stride_lora: tl.constexpr, lora_a_stride_expert: tl.constexpr, lora_a_stride_rank: tl.constexpr, lora_a_stride_hidden: tl.constexpr, + + # Strides for 4D LoRA B weights [num_loras, num_experts, intermediate_dim, max_rank] lora_b_stride_lora: tl.constexpr, lora_b_stride_expert: tl.constexpr, lora_b_stride_intermediate: tl.constexpr, lora_b_stride_rank: tl.constexpr, - # LoRA ranks per adapter + + # LoRA ranks per adapter [num_loras] lora_ranks_ptr, - # Scaling factors per adapter + # Scaling factors per adapter [num_loras] lora_scalings_ptr, - # Block sizes - BLOCK_HIDDEN: tl.constexpr, - BLOCK_INTERMEDIATE: tl.constexpr, - BLOCK_RANK: tl.constexpr, + + # Block size (used for hidden and output tiling; rank is not tiled) + BLOCK_SIZE: tl.constexpr, ): """ - Compute per-expert LoRA delta: delta = B @ A @ hidden_states. 
+ Compute per-expert LoRA delta: - Grid: (spatial_tiles, intermediate_slices, num_loras) - - spatial_tiles: Number of token tiles - - intermediate_slices: Number of output dimension tiles - - num_loras: Process each LoRA adapter in parallel + delta[token, out_slice, lora] = B[out_slice, :] @ (A @ hidden_states[token]) + + 3D Grid: (spatial, slices, loras) + - spatial = program_id(0): dispatched token index + - slices = program_id(1): tile index along intermediate_dim + - loras = program_id(2): LoRA adapter index """ - # Grid IDs - token_tile_id = tl.program_id(0) - output_tile_id = tl.program_id(1) - lora_id = tl.program_id(2) - # Get rank and scaling for this LoRA adapter - rank = tl.load(lora_ranks_ptr + lora_id) - scaling = tl.load(lora_scalings_ptr + lora_id) + # 3D grid indices + spatial_id = tl.program_id(0) # dispatched token index + slice_id = tl.program_id(1) # output slice index + lora_id_grid = tl.program_id(2) # LoRA adapter index - # Early exit if rank is 0 - if rank == 0: + # Bounds check on dispatched tokens + if spatial_id >= num_dispatched: return - # Token range for this tile - token_start = token_tile_id * BLOCK_HIDDEN - token_end = tl.minimum(token_start + BLOCK_HIDDEN, num_tokens) - - # Output dimension range - out_start = output_tile_id * BLOCK_INTERMEDIATE - out_end = tl.minimum(out_start + BLOCK_INTERMEDIATE, intermediate_dim) - - # Process each token in this tile - for token_idx in range(token_start, token_end): - if token_idx >= num_tokens: - break - - # Load dispatch info for this token - actual_token_id = tl.load(token_ids_ptr + token_idx) - expert_id = tl.load(expert_ids_ptr + token_idx) - token_lora_id = tl.load(lora_ids_ptr + token_idx) - - # Skip if this token doesn't belong to current LoRA - if token_lora_id != lora_id: - continue - - # Load hidden states for this token: [hidden_dim] - hidden_ptr = hidden_states_ptr + actual_token_id * hidden_dim - hidden_offs = tl.arange(0, BLOCK_HIDDEN) - hidden_mask = hidden_offs < hidden_dim - 
hidden = tl.load(hidden_ptr + hidden_offs, mask=hidden_mask, other=0.0) - - # Compute A @ hidden: [rank] = [rank, hidden_dim] @ [hidden_dim] - intermediate_a = tl.zeros([BLOCK_RANK], dtype=tl.float32) - - for k_tile in range(0, tl.cdiv(hidden_dim, BLOCK_HIDDEN)): - k_start = k_tile * BLOCK_HIDDEN - k_offs = tl.arange(0, BLOCK_HIDDEN) + k_start - k_mask = k_offs < hidden_dim - - # Load from hidden states - h_vals = tl.load(hidden_ptr + k_offs, mask=k_mask, other=0.0) - - # Load LoRA A weights: [rank, hidden_dim] - for r in range(BLOCK_RANK): - if r >= rank: - break - lora_a_offset = ( - lora_id * lora_a_stride_lora - + expert_id * lora_a_stride_expert - + r * lora_a_stride_rank - + k_start * lora_a_stride_hidden - ) - a_vals = tl.load( - lora_a_weights_ptr + lora_a_offset + k_offs, - mask=k_mask, - other=0.0, - ) - intermediate_a[r] += tl.sum(a_vals * h_vals) - - # Compute B @ intermediate_a: [intermediate_dim] = [intermediate_dim, rank] @ [rank] - out_offs = tl.arange(0, BLOCK_INTERMEDIATE) + out_start - out_mask = out_offs < intermediate_dim - - output_vals = tl.zeros([BLOCK_INTERMEDIATE], dtype=tl.float32) - - for r in range(BLOCK_RANK): - if r >= rank: - break - - # Load LoRA B weights: [intermediate_dim, rank] - lora_b_offset = ( - lora_id * lora_b_stride_lora - + expert_id * lora_b_stride_expert - + out_start * lora_b_stride_intermediate - + r * lora_b_stride_rank - ) - b_vals = tl.load( - lora_b_weights_ptr - + lora_b_offset - + out_offs * lora_b_stride_intermediate, - mask=out_mask, - other=0.0, - ) - output_vals += b_vals * intermediate_a[r] - - # Scale and accumulate to output - output_vals *= scaling - output_offset = actual_token_id * intermediate_dim + out_start - tl.atomic_add(output_ptr + output_offset + out_offs, output_vals, mask=out_mask) + # Load dispatch info for this dispatched index + actual_token_id = tl.load(token_ids_ptr + spatial_id) + expert_id = tl.load(expert_ids_ptr + spatial_id) + token_lora_id = tl.load(lora_ids_ptr + spatial_id) + + 
# Skip if this token does not use this LoRA adapter + if token_lora_id != lora_id_grid: + return + + # Load LoRA rank and scaling (scalar tensors) for this LoRA adapter + rank = tl.load(lora_ranks_ptr + lora_id_grid) + scaling = tl.load(lora_scalings_ptr + lora_id_grid) + has_rank = rank > 0 + if not has_rank: + return + # ---------------------------- + # Base pointers + # ---------------------------- + # hidden_states[actual_token_id, :] + hidden_ptr = hidden_states_ptr + actual_token_id * hidden_dim + + # A[lora_id_grid, expert_id, :, :] + lora_a_base = ( + lora_a_weights_ptr + + lora_id_grid * lora_a_stride_lora + + expert_id * lora_a_stride_expert + ) + + # B[lora_id_grid, expert_id, :, :] + lora_b_base = ( + lora_b_weights_ptr + + lora_id_grid * lora_b_stride_lora + + expert_id * lora_b_stride_expert + ) + + # ---------------------------- + # Stage 1: intermediate = A @ hidden + # ---------------------------- + + # We assume max_rank is small enough to keep as a single 1D vector + r_offs = tl.arange(0, max_rank) # [max_rank] + rank_mask = r_offs < rank # [max_rank] + + # Accumulator for intermediate: [max_rank] + intermediate = tl.zeros((max_rank,), dtype=tl.float32) + + # Tile over hidden_dim in chunks of BLOCK_SIZE + NUM_HIDDEN_TILES = (hidden_dim + BLOCK_SIZE - 1) // BLOCK_SIZE + for hidden_tile_idx in range(NUM_HIDDEN_TILES): + hidden_start = hidden_tile_idx * BLOCK_SIZE + hidden_offs = hidden_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] + hidden_mask = hidden_offs < hidden_dim # [BLOCK_SIZE] + + # Load hidden values for this tile: [BLOCK_SIZE] + h_vals = tl.load( + hidden_ptr + hidden_offs, + mask=hidden_mask, + other=0.0, + ).to(tl.float32) + + # Build [max_rank, BLOCK_SIZE] tile of A: + # rows: r_offs + # cols: hidden_offs + # offset = base + r * stride_rank + h * stride_hidden + a_ptrs = ( + lora_a_base + + r_offs[:, None] * lora_a_stride_rank + + hidden_offs[None, :] * lora_a_stride_hidden + ) + a_vals = tl.load( + a_ptrs, + mask=rank_mask[:, None] 
& hidden_mask[None, :], + other=0.0, + ).to(tl.float32) + + # Dot over hidden axis: [max_rank] + # intermediate[r] += sum_h A[r, h] * h_vals[h] + intermediate += tl.sum(a_vals * h_vals[None, :], axis=1) + + # ---------------------------- + # Stage 2: y_slice = B[out_slice, :] @ intermediate + # One output slice per program along intermediate_dim. + # ---------------------------- + out_start = slice_id * BLOCK_SIZE + out_offs = out_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] + out_mask = out_offs < intermediate_dim # [BLOCK_SIZE] + + # If this slice is entirely out of bounds, we can early-exit + # (not strictly necessary but cheap) + # NOTE: Triton doesn't have a direct "if not any(mask)" primitive, + # but the mask will naturally guard loads/stores below, so this is safe to omit. + # We'll just rely on masks. + + # Build [max_rank, BLOCK_SIZE] tile of B: + # rows: r_offs (rank dimension) + # cols: out_offs (output dimension) + # offset = base + out * stride_intermediate + r * stride_rank + b_ptrs = ( + lora_b_base + + out_offs[None, :] * lora_b_stride_intermediate + + r_offs[:, None] * lora_b_stride_rank + ) + b_vals = tl.load( + b_ptrs, + mask=rank_mask[:, None] & out_mask[None, :], + other=0.0, + ).to(tl.float32) + + # Contribution: + # out_vals[j] = sum_r B[j, r] * intermediate[r] + out_vals = tl.sum(b_vals * intermediate[:, None], axis=0) # [BLOCK_SIZE] + + # Apply scaling + out_vals *= scaling + + # ---------------------------- + # Accumulate into global output + # ---------------------------- + out_row_base = actual_token_id * intermediate_dim + out_ptrs = output_ptr + out_row_base + out_offs + + tl.atomic_add( + out_ptrs, + out_vals, + mask=out_mask & has_rank, + ) def per_expert_lora_forward( hidden_states: torch.Tensor, @@ -176,7 +220,8 @@ def per_expert_lora_forward( base_output: torch.Tensor = None, ) -> torch.Tensor: """ - Forward pass for per-expert LoRA computation. 
+ Forward pass for per-expert LoRA computation using a 3D Triton grid: + grid = (spatial, slices, loras) Args: hidden_states: [num_tokens, hidden_dim] @@ -193,61 +238,82 @@ def per_expert_lora_forward( Returns: output: [num_tokens, intermediate_dim] - Base output + LoRA delta (in-place) """ + # Shapes num_tokens, hidden_dim = hidden_states.shape num_loras, _, intermediate_dim, max_rank = lora_b_weights.shape num_dispatched = token_ids.shape[0] + # Make sure everything is on the same device and contiguous + device = hidden_states.device + hidden_states = hidden_states.contiguous() + lora_a_weights = lora_a_weights.contiguous() + lora_b_weights = lora_b_weights.contiguous() + token_ids = token_ids.contiguous() + expert_ids = expert_ids.contiguous() + lora_ids = lora_ids.contiguous() + lora_ranks = lora_ranks.contiguous() + lora_scalings = lora_scalings.contiguous() + # Initialize or reuse output tensor for in-place addition if base_output is None: + # Use float32 for accumulation; you can cast back if needed output = torch.zeros( num_tokens, - intermediate_dim, - dtype=hidden_states.dtype, - device=hidden_states.device, + hidden_dim, + dtype=torch.float32, + device=device, ) else: output = base_output + assert output.shape == (num_tokens, hidden_dim) # TODO (jonahcb): check if this is correct + assert output.device == device - # Block sizes (tuned for typical dimensions) - BLOCK_HIDDEN = 128 - BLOCK_INTERMEDIATE = 128 - BLOCK_RANK = 64 + # Tile size for hidden and output dimensions + BLOCK_SIZE = 64 # tune as needed - # Grid dimensions: (spatial_tiles, intermediate_slices, num_loras) - grid = ( - triton.cdiv(num_dispatched, BLOCK_HIDDEN), - triton.cdiv(intermediate_dim, BLOCK_INTERMEDIATE), - num_loras, - ) + # Number of output slices along intermediate_dim + num_slices = (intermediate_dim + BLOCK_SIZE - 1) // BLOCK_SIZE + + # 3D grid: (spatial, slices, loras) + grid = (num_dispatched, num_slices, num_loras) _per_expert_lora_kernel[grid]( - hidden_states, - 
lora_a_weights, - lora_b_weights, - output, - token_ids, - expert_ids, - lora_ids, - hidden_dim, - intermediate_dim, - max_rank, - num_experts, - num_dispatched, + # Pointers + hidden_states, # hidden_states_ptr + lora_a_weights, # lora_a_weights_ptr + lora_b_weights, # lora_b_weights_ptr + output, # output_ptr + + # Dispatch info + token_ids, # token_ids_ptr + expert_ids, # expert_ids_ptr + lora_ids, # lora_ids_ptr + + # Dimensions + hidden_dim, # hidden_dim + intermediate_dim, # intermediate_dim + max_rank, # max_rank + num_experts, # num_experts + num_dispatched, # num_dispatched (runtime scalar) + # LoRA A strides: [num_loras, num_experts, max_rank, hidden_dim] lora_a_weights.stride(0), lora_a_weights.stride(1), lora_a_weights.stride(2), lora_a_weights.stride(3), + # LoRA B strides: [num_loras, num_experts, intermediate_dim, max_rank] lora_b_weights.stride(0), lora_b_weights.stride(1), lora_b_weights.stride(2), lora_b_weights.stride(3), - lora_ranks, - lora_scalings, - BLOCK_HIDDEN, - BLOCK_INTERMEDIATE, - BLOCK_RANK, + + # Rank & scaling + lora_ranks, # lora_ranks_ptr + lora_scalings, # lora_scalings_ptr + + # Block size (constexpr) + BLOCK_SIZE=BLOCK_SIZE, ) return output diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 7037fc4a686c..aad89c371cd1 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -85,9 +85,9 @@ def get_hidden_dim( head_dim * config.num_attention_heads, config.hidden_size, ) - elif module_name == "gate_up_proj": + elif module_name == "gate_up_proj" or module_name == "gate_up_proj_moe": return config.hidden_size, config.intermediate_size * 2 - elif module_name == "down_proj": + elif module_name == "down_proj" or module_name == "down_proj_moe": return config.intermediate_size, config.hidden_size else: raise NotImplementedError() @@ -123,6 +123,7 @@ def get_stacked_multiply(module_name: str) -> int: stacked_rank = { "qkv_proj": 3, "gate_up_proj": 2, + "gate_up_proj_moe": 2, } 
return stacked_rank[module_name] if module_name in stacked_rank else 1 diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index bd4f6121c841..b0564032d746 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -282,6 +282,8 @@ class ForwardBatch: # For LoRA lora_ids: Optional[List[str]] = None + # Per-token LoRA adapter indices (expanded from lora_ids) + token_lora_indices: Optional[torch.Tensor] = None # For input embeddings input_embeds: Optional[torch.Tensor] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fb3dbd4bedf2..32f32a0b4952 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1584,7 +1584,6 @@ def init_memory_pool( ) log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}") - self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) if SGLANG_CI_SMALL_KV_SIZE: self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE) From ba64942d862806377d901eaa5e5418ec7c0f7e4f Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 19 Nov 2025 03:46:32 +0000 Subject: [PATCH 026/150] add test --- test/srt/lora/test_lora_moe.py | 478 +++++++++++++++++++++++++++++++++ 1 file changed, 478 insertions(+) create mode 100644 test/srt/lora/test_lora_moe.py diff --git a/test/srt/lora/test_lora_moe.py b/test/srt/lora/test_lora_moe.py new file mode 100644 index 000000000000..6cd714ac722a --- /dev/null +++ b/test/srt/lora/test_lora_moe.py @@ -0,0 +1,478 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Test MoE LoRA implementation by comparing against HuggingFace. + +This test file verifies that SGLang's MoE LoRA implementation produces the same +outputs as HuggingFace's PEFT library for MoE models. + +WORKFLOW: +1. Run SGLang test in first session (loads SGLang model, saves results to temp file) +2. Run HF test in second session (loads HF model, saves results to temp file) +3. Compare saved results from both sessions + +This avoids loading both models simultaneously, saving GPU memory. + +Usage: + # Run basic functionality test (no HF comparison) + python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_basic_functionality -v + + # Run full HF comparison test (requires model and LoRA adapter) + # This will run SGLang and HF tests separately, then compare results + python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_qwen15 -v + +Manual Testing (run separately): + # Terminal 1: Run SGLang test and save results + python -c " + from test.srt.lora.test_lora_moe import TestMoELoRA + import torch + import tempfile + import os + + test = TestMoELoRA() + model_case = test.MOE_LORA_TEST_CASES[0] + + with tempfile.TemporaryDirectory() as temp_dir: + result_file = os.path.join(temp_dir, 'srt_results.pkl') + test._run_srt_test(result_file, model_case, torch.float16) + print(f'SGLang results saved to {result_file}') + input('Press Enter after copying result file...') + " + + # Terminal 2: Run HF test and save results + python -c " + 
from test.srt.lora.test_lora_moe import TestMoELoRA + import torch + import tempfile + import os + + test = TestMoELoRA() + model_case = test.MOE_LORA_TEST_CASES[0] + + with tempfile.TemporaryDirectory() as temp_dir: + result_file = os.path.join(temp_dir, 'hf_results.pkl') + test._run_hf_test(result_file, model_case, torch.float16) + print(f'HF results saved to {result_file}') + input('Press Enter after copying result file...') + " + +Requirements: + - Qwen/Qwen1.5-MoE-A2.7B model + - sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest LoRA adapter + - Sufficient GPU memory (tests run in separate sessions) + - Uses same memory configuration as launch.json for compatibility +""" + +import json +import multiprocessing as mp +import os +import pickle +import random +import tempfile +import time +import unittest +from pathlib import Path + +from utils import LoRAModelCase, LoRAAdaptor, ensure_reproducibility + +from sglang.test.runners import HFRunner, SRTRunner +from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci + +# Test prompts for MoE LoRA comparison +TEST_MOE_PROMPTS = [ + "The capital of France is Paris. 
The capital of", + "Explain what mixture of experts means in machine learning.", + "Write a short poem about artificial intelligence and large language models.", + "What are the benefits of using MoE architectures in transformers?", +] + +# MoE model test cases with LoRA adapters +MOE_LORA_TEST_CASES = [ + LoRAModelCase( + base="Qwen/Qwen1.5-MoE-A2.7B", + adaptors=[ + # Use a real LoRA adapter path - replace with actual path when testing + LoRAAdaptor( + name="sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest", + prefill_tolerance=1e-1, + decode_tolerance=1e-1, + rouge_l_tolerance=1.0 + ), + ], + tp_size=1, + prefill_tolerance=1e-1, + decode_tolerance=1e-1, + rouge_l_tolerance=1.0, + max_loras_per_batch=1, + ), + # Add more MoE models here when available +] + + +class TestMoELoRA(CustomTestCase): + """Test MoE LoRA implementation by comparing against HuggingFace.""" + + def _get_srt_server_args(self): + """Get SGLang server arguments from launch.json config.""" + return { + "quantization": "fp8", + "disable_radix_cache": True, + "lora_backend": "csgmv", + "max_lora_chunk_size": 16, + "port": 30000, + "host": "127.0.0.1", + "max_loras_per_batch": 1, + "tp_size": 2, + "max_total_tokens": 128, + "page_size": 64, + "max_running_requests": 1, + "mem_fraction_static": 0.85, + } + + def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): + """Run SGLang test and save results to file.""" + try: + base_path = model_case.base + adaptor_names = [adaptor.name for adaptor in model_case.adaptors] + + print(f"\n========== Running SGLang MoE LoRA test on '{base_path}', dtype={torch_dtype} ==========") + + server_args = self._get_srt_server_args() + server_args["model_path"] = base_path + server_args["lora_paths"] = adaptor_names + server_args["enable_lora"] = True + + # Initialize SGLang runner with launch.json args + srt_runner = SRTRunner( + model_path=base_path, + torch_dtype=torch_dtype, + model_type="generation", + tp_size=server_args["tp_size"], + 
lora_paths=adaptor_names, + max_loras_per_batch=server_args["max_loras_per_batch"], + lora_backend=server_args["lora_backend"], + disable_cuda_graph=False, + disable_radix_cache=server_args["disable_radix_cache"], + max_total_tokens=server_args["max_total_tokens"], + page_size=server_args["page_size"], + mem_fraction_static=server_args["mem_fraction_static"], + sleep_on_idle=True, + # Pass additional args via json_model_override_args + json_model_override_args={ + "quantization": server_args["quantization"], + "max_lora_chunk_size": server_args["max_lora_chunk_size"], + "max_running_requests": server_args["max_running_requests"], + } + ) + + results = {} + + # Test with different batch configurations + test_configs = [ + {"batch_size": 1, "lora_paths": [adaptor_names[0]]}, # Single request, single LoRA + {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA + ] + + with srt_runner: + for config in test_configs: + batch_size = config["batch_size"] + lora_paths = config["lora_paths"] + + # Use fixed prompts for reproducibility + prompts = TEST_MOE_PROMPTS[:batch_size] + + config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" + + print(f"\n--- SRT Testing {config_key} ---") + print(f"Prompts: {prompts}") + + # Ensure reproducibility + ensure_reproducibility() + + # Run SGLang + srt_outputs = srt_runner.batch_forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + + results[config_key] = { + "prompts": prompts, + "lora_paths": lora_paths, + "srt_outputs": { + "output_strs": srt_outputs.output_strs, + "top_input_logprobs": srt_outputs.top_input_logprobs, + "top_output_logprobs": srt_outputs.top_output_logprobs, + } + } + + # Save results + with open(result_file, 'wb') as f: + pickle.dump(results, f) + print(f"SGLang results saved to {result_file}") + + except Exception as e: + print(f"SGLang test failed: {e}") + with open(result_file, 'wb') as f: + pickle.dump({"error": str(e)}, f) + + 
def _run_hf_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): + """Run HuggingFace test and save results to file.""" + try: + base_path = model_case.base + adaptor_names = [adaptor.name for adaptor in model_case.adaptors] + + print(f"\n========== Running HF MoE LoRA test on '{base_path}', dtype={torch_dtype} ==========") + + # Initialize HF runner + hf_runner = HFRunner( + base_path, torch_dtype=torch_dtype, model_type="generation" + ) + + results = {} + + # Test with different batch configurations + test_configs = [ + {"batch_size": 1, "lora_paths": [adaptor_names[0]]}, # Single request, single LoRA + {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA + ] + + with hf_runner: + for config in test_configs: + batch_size = config["batch_size"] + lora_paths = config["lora_paths"] + + # Use fixed prompts for reproducibility + prompts = TEST_MOE_PROMPTS[:batch_size] + + config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" + + print(f"\n--- HF Testing {config_key} ---") + print(f"Prompts: {prompts}") + + # Ensure reproducibility + ensure_reproducibility() + + # Run HuggingFace + hf_outputs = hf_runner.forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + + results[config_key] = { + "prompts": prompts, + "lora_paths": lora_paths, + "hf_outputs": { + "output_strs": hf_outputs.output_strs, + "top_input_logprobs": hf_outputs.top_input_logprobs, + "top_output_logprobs": hf_outputs.top_output_logprobs, + } + } + + # Save results + with open(result_file, 'wb') as f: + pickle.dump(results, f) + print(f"HF results saved to {result_file}") + + except Exception as e: + print(f"HF test failed: {e}") + with open(result_file, 'wb') as f: + pickle.dump({"error": str(e)}, f) + + def _run_moe_lora_comparison(self, model_case: LoRAModelCase, torch_dtype, max_new_tokens=32): + """Run LoRA comparison test by loading saved results from separate sessions.""" + base_path = 
model_case.base + + # Create temp directory for results + with tempfile.TemporaryDirectory() as temp_dir: + srt_result_file = os.path.join(temp_dir, 'srt_results.pkl') + hf_result_file = os.path.join(temp_dir, 'hf_results.pkl') + + print(f"\n========== Testing MoE LoRA on '{base_path}', dtype={torch_dtype} ==========") + + # Check if results already exist + srt_done = os.path.exists(srt_result_file) + hf_done = os.path.exists(hf_result_file) + + if not srt_done: + print("Running SGLang test...") + self._run_srt_test(srt_result_file, model_case, torch_dtype, max_new_tokens) + srt_done = True + + if not hf_done: + print("Running HuggingFace test...") + self._run_hf_test(hf_result_file, model_case, torch_dtype, max_new_tokens) + hf_done = True + + # Load results + if srt_done: + with open(srt_result_file, 'rb') as f: + srt_results = pickle.load(f) + if "error" in srt_results: + self.fail(f"SGLang test failed: {srt_results['error']}") + + if hf_done: + with open(hf_result_file, 'rb') as f: + hf_results = pickle.load(f) + if "error" in hf_results: + self.fail(f"HF test failed: {hf_results['error']}") + + # Compare results + if srt_done and hf_done: + for config_key in srt_results.keys(): + if config_key in hf_results: + srt_data = srt_results[config_key] + hf_data = hf_results[config_key] + + self._compare_outputs( + srt_data["srt_outputs"], hf_data["hf_outputs"], + model_case, srt_data["prompts"], srt_data["lora_paths"] + ) + + def _compare_outputs(self, srt_outputs, hf_outputs, model_case, prompts, lora_paths): + """Compare SGLang and HF outputs.""" + for i, (prompt, lora_path) in enumerate(zip(prompts, lora_paths)): + print(f"\nRequest {i}: lora_path='{lora_path}'") + print(f"Prompt: {prompt[:50]}{'...' if len(prompt) > 50 else ''}") + + # Compare output strings + srt_output_str = srt_outputs.output_strs[i].strip() + hf_output_str = hf_outputs.output_strs[i].strip() + + print(f"SRT output: {srt_output_str[:100]}{'...' 
if len(srt_output_str) > 100 else ''}") + print(f"HF output: {hf_output_str[:100]}{'...' if len(hf_output_str) > 100 else ''}") + + # Calculate ROUGE-L similarity + rouge_l = calculate_rouge_l(srt_output_str, hf_output_str) + print(f"ROUGE-L similarity: {rouge_l:.4f}") + + # Check ROUGE-L tolerance + self.assertGreaterEqual( + rouge_l, + model_case.rouge_l_tolerance, + f"ROUGE-L similarity {rouge_l:.4f} below tolerance {model_case.rouge_l_tolerance} " + f"for request {i} with lora_path='{lora_path}'" + ) + + # Compare logprobs if available + if hasattr(srt_outputs, 'top_input_logprobs') and hasattr(hf_outputs, 'top_input_logprobs'): + if srt_outputs.top_input_logprobs[i] is not None and hf_outputs.top_input_logprobs[i] is not None: + import torch + srt_prefill = torch.tensor(srt_outputs.top_input_logprobs[i]) + hf_prefill = torch.tensor(hf_outputs.top_input_logprobs[i]) + + max_prefill_diff = torch.max(torch.abs(hf_prefill - srt_prefill)) + print(f"Max prefill logprob diff: {max_prefill_diff:.6f}") + + # Check prefill tolerance + prefill_tol = model_case.prefill_tolerance + self.assertLessEqual( + max_prefill_diff, + prefill_tol, + f"Prefill logprob diff {max_prefill_diff:.6f} exceeds tolerance {prefill_tol} " + f"for request {i} with lora_path='{lora_path}'" + ) + + if hasattr(srt_outputs, 'top_output_logprobs') and hasattr(hf_outputs, 'top_output_logprobs'): + if srt_outputs.top_output_logprobs[i] is not None and hf_outputs.top_output_logprobs[i] is not None: + import torch + srt_decode = torch.tensor(srt_outputs.top_output_logprobs[i]) + hf_decode = torch.tensor(hf_outputs.top_output_logprobs[i]) + + max_decode_diff = torch.max(torch.abs(hf_decode - srt_decode)) + print(f"Max decode logprob diff: {max_decode_diff:.6f}") + + # Check decode tolerance + decode_tol = model_case.decode_tolerance + self.assertLessEqual( + max_decode_diff, + decode_tol, + f"Decode logprob diff {max_decode_diff:.6f} exceeds tolerance {decode_tol} " + f"for request {i} with 
lora_path='{lora_path}'" + ) + + def test_moe_lora_qwen15(self): + """Test LoRA on Qwen1.5-MoE-A2.7B.""" + if is_in_ci(): + self.skipTest("Skipping MoE LoRA test in CI environment") + + model_case = MOE_LORA_TEST_CASES[0] + + # Test with different dtypes + import torch + for torch_dtype in [torch.float16, torch.bfloat16]: + with self.subTest(dtype=torch_dtype): + try: + self._run_moe_lora_comparison(model_case, torch_dtype) + except Exception as e: + self.fail(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") + + def test_moe_lora_basic_functionality(self): + """Basic functionality test for MoE LoRA dispatch.""" + # This test focuses on the core dispatch logic without full HF comparison + # Useful for debugging the MoE LoRA implementation + + import torch + from sglang.srt.lora.moe_dispatch import moe_dispatch + + # Create test data + num_tokens = 4 + top_k = 2 + num_experts = 8 + + # Mock top-k routing results + topk_ids = torch.tensor([ + [0, 1], # token 0 routes to experts 0, 1 + [2, 3], # token 1 routes to experts 2, 3 + [1, 4], # token 2 routes to experts 1, 4 + [5, 6], # token 3 routes to experts 5, 6 + ], dtype=torch.int32) + + topk_weights = torch.ones_like(topk_ids, dtype=torch.float32) + + # Mock LoRA indices (one per token) + lora_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32) # tokens 0,1 use lora 0; tokens 2,3 use lora 1 + + # Run dispatch + token_ids, expert_ids, weights = moe_dispatch( + topk_ids=topk_ids, + topk_weights=topk_weights, + lora_indices=lora_indices, + num_experts=num_experts, + num_loras=2, + ) + + # Verify results + # Should have 4 tokens * 2 experts each = 8 dispatched entries + self.assertEqual(len(token_ids), 8) + self.assertEqual(len(expert_ids), 8) + self.assertEqual(len(weights), 8) + + # Check that tokens are grouped by expert (not by LoRA) + # All tokens going to expert 0 should come first, then expert 1, etc. 
+ unique_experts, expert_counts = torch.unique_consecutive(expert_ids, return_counts=True) + self.assertTrue(torch.all(expert_counts >= 1)) # Each expert should have at least one token + + print(f"Dispatch successful: {len(token_ids)} dispatched tokens to experts {unique_experts.tolist()}") + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main() From c339cee4f3d306bb721c5fed7413b80352c19e88 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 19 Nov 2025 20:14:47 +0000 Subject: [PATCH 027/150] fix --- python/sglang/srt/lora/mem_pool.py | 31 +++ python/sglang/test/runners.py | 80 +++++- test/srt/lora/test_lora_moe.py | 415 ++++++++++++++++++++++++++--- 3 files changed, 487 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index d68cf2b913f0..76985439e24e 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -383,6 +383,37 @@ def load_lora_weight_tensor( if self.tp_size > 1: cur_layer_modules = lora_modules[layer_id] for module_name, module in cur_layer_modules.items(): + # TODO (Jonahcb): check if the code can be refactored to avoid the special handling for FusedMoEWithLoRA + # Handle FusedMoEWithLoRA specially - it contains multiple target modules + from sglang.srt.layers.moe.lora_moe import FusedMoEWithLoRA + if isinstance(module, FusedMoEWithLoRA): + # FusedMoEWithLoRA contains both gate_up_proj and down_proj + moe_target_modules = ['gate_up_proj_moe', 'down_proj_moe'] + for target_module in moe_target_modules: + + if temp_A_buffer[target_module] is None: + # Skip weight slicing if the weight is not present in the adapter + continue + + # Handle MoE modules (they contain dicts of per-expert tensors) + # Slice each expert's weights individually + for expert_id in temp_A_buffer[target_module].keys(): + temp_A_buffer[target_module][expert_id] = ( + module.slice_lora_a_weights( + 
temp_A_buffer[target_module][expert_id], + self.tp_rank, + ) + ) + temp_B_buffer[target_module][expert_id] = ( + module.slice_lora_b_weights( + temp_B_buffer[target_module][expert_id], + self.tp_rank, + ) + ) + + continue + + # Handle regular modules target_module = get_target_module_name( module_name, self.target_modules ) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index e469a3c035a6..9fa98246affe 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -16,7 +16,7 @@ import multiprocessing as mp import os from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -142,11 +142,31 @@ def __init__( trust_remote_code: bool = False, patch_model_do_sample_false: bool = False, matryoshka_dim: Optional[int] = None, + # Memory and device management (matching SRTRunner capabilities) + max_memory: Optional[Dict[str, str]] = None, + device_map: Optional[str] = None, + load_in_8bit: bool = False, + load_in_4bit: bool = False, + offload_folder: Optional[str] = None, + # Additional parameters for consistency with SRTRunner + mem_fraction_static: Optional[float] = None, + max_total_tokens: Optional[int] = None, + max_running_requests: Optional[int] = None, + quantization: Optional[str] = None, ): self.model_type = model_type self.output_str_only = output_str_only self.trust_remote_code = trust_remote_code self.patch_model_do_sample_false = patch_model_do_sample_false + self.max_memory = max_memory + self.device_map = device_map + self.load_in_8bit = load_in_8bit + self.load_in_4bit = load_in_4bit + self.offload_folder = offload_folder + self.mem_fraction_static = mem_fraction_static + self.max_total_tokens = max_total_tokens + self.max_running_requests = max_running_requests + self.quantization = quantization self.in_queue = mp.Queue() self.out_queue = mp.Queue() @@ -159,6 +179,12 @@ def 
__init__( model_path, torch_dtype, matryoshka_dim, + self.max_memory, + self.device_map, + self.load_in_8bit, + self.load_in_4bit, + self.offload_folder, + self.quantization, ), ) self.model_proc.start() @@ -240,10 +266,41 @@ def start_model_process( model_path, torch_dtype, matryoshka_dim: Optional[int] = None, + max_memory=None, + device_map=None, + load_in_8bit=False, + load_in_4bit=False, + offload_folder=None, + quantization=None, ): # Apply model-specific patches monkey_patch_gemma2_sdpa() + # Build common kwargs for from_pretrained calls + common_kwargs = { + "torch_dtype": torch_dtype, + "trust_remote_code": self.trust_remote_code, + "low_cpu_mem_usage": True, + } + + # Add optional memory/device parameters + if max_memory is not None: + common_kwargs["max_memory"] = max_memory + if device_map is not None: + common_kwargs["device_map"] = device_map + if load_in_8bit: + common_kwargs["load_in_8bit"] = load_in_8bit + if load_in_4bit: + common_kwargs["load_in_4bit"] = load_in_4bit + if offload_folder is not None: + common_kwargs["offload_folder"] = offload_folder + + # Handle quantization parameter (map to appropriate HF parameter) + if quantization == "8bit": + common_kwargs["load_in_8bit"] = True + elif quantization == "4bit": + common_kwargs["load_in_4bit"] = True + # Load the model and tokenizer if self.model_type == "generation": config = AutoConfig.from_pretrained( @@ -256,17 +313,15 @@ def start_model_process( model_cls = getattr(transformers, model_arch) self.base_model = model_cls.from_pretrained( model_path, - torch_dtype=torch_dtype, - trust_remote_code=self.trust_remote_code, - low_cpu_mem_usage=True, + **common_kwargs ).cuda() elif self.model_type == "embedding": if "gme-qwen2-vl" in model_path.lower(): + embedding_kwargs = common_kwargs.copy() + embedding_kwargs["trust_remote_code"] = False # Override for this specific model self.model = AutoModelForVision2Seq.from_pretrained( model_path, - torch_dtype=torch_dtype, - trust_remote_code=False, - 
low_cpu_mem_usage=True, + **embedding_kwargs ).cuda() self.processor = AutoProcessor.from_pretrained(model_path) elif "clip" in model_path.lower(): @@ -279,10 +334,11 @@ def start_model_process( elif self.model_type == "reward" or self.model_type == "cross_encoder": from transformers import AutoModelForSequenceClassification + reward_kwargs = common_kwargs.copy() + reward_kwargs["trust_remote_code"] = self.needs_trust_remote_code(model_path) self.model = AutoModelForSequenceClassification.from_pretrained( model_path, - torch_dtype=torch_dtype, - trust_remote_code=self.needs_trust_remote_code(model_path), + **reward_kwargs ).cuda() else: raise Exception(f"Unrecognized model type {self.model_type}") @@ -508,6 +564,9 @@ def __init__( port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None, max_loras_per_batch: int = 4, + quantization: Optional[str] = None, + max_lora_chunk_size: Optional[int] = None, + max_running_requests: Optional[int] = None, attention_backend: Optional[str] = None, prefill_attention_backend: Optional[str] = None, decode_attention_backend: Optional[str] = None, @@ -567,6 +626,9 @@ def __init__( is_embedding=not self.is_generation, lora_paths=lora_paths, max_loras_per_batch=max_loras_per_batch, + quantization=quantization, + max_lora_chunk_size=max_lora_chunk_size, + max_running_requests=max_running_requests, lora_backend=lora_backend, attention_backend=attention_backend, prefill_attention_backend=prefill_attention_backend, diff --git a/test/srt/lora/test_lora_moe.py b/test/srt/lora/test_lora_moe.py index 6cd714ac722a..08bc1ec71637 100644 --- a/test/srt/lora/test_lora_moe.py +++ b/test/srt/lora/test_lora_moe.py @@ -33,7 +33,11 @@ # This will run SGLang and HF tests separately, then compare results python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_qwen15 -v -Manual Testing (run separately): +Manual Testing: + # Run full SRT vs HF comparison (sequential, saves memory) 
+ python test_lora_moe.py --debug + + # Or run individual components separately: # Terminal 1: Run SGLang test and save results python -c " from test.srt.lora.test_lora_moe import TestMoELoRA @@ -76,12 +80,15 @@ """ import json +import logging import multiprocessing as mp import os import pickle +import psutil import random import tempfile import time +import torch import unittest from pathlib import Path @@ -124,6 +131,33 @@ class TestMoELoRA(CustomTestCase): """Test MoE LoRA implementation by comparing against HuggingFace.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Set up detailed logging + logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + self.logger = logging.getLogger(__name__) + + def _log_system_info(self): + """Log system and GPU memory information.""" + try: + # CPU memory + memory = psutil.virtual_memory() + self.logger.info(f"System Memory: {memory.available / (1024**3):.2f}GB available / {memory.total / (1024**3):.2f}GB total") + + # GPU memory if available + if torch.cuda.is_available(): + gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) + allocated = torch.cuda.memory_allocated(0) / (1024**3) + reserved = torch.cuda.memory_reserved(0) / (1024**3) + self.logger.info(f"GPU Memory: {gpu_memory:.2f}GB total, {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") + else: + self.logger.warning("CUDA not available") + except Exception as e: + self.logger.warning(f"Could not get system info: {e}") + def _get_srt_server_args(self): """Get SGLang server arguments from launch.json config.""" return { @@ -144,17 +178,26 @@ def _get_srt_server_args(self): def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): """Run SGLang test and save results to file.""" try: + self.logger.info("=== Starting SGLang MoE LoRA test ===") + self._log_system_info() + base_path = model_case.base adaptor_names = [adaptor.name for 
adaptor in model_case.adaptors] - print(f"\n========== Running SGLang MoE LoRA test on '{base_path}', dtype={torch_dtype} ==========") + self.logger.info(f"Model: {base_path}") + self.logger.info(f"LoRA adapters: {adaptor_names}") + self.logger.info(f"dtype: {torch_dtype}") server_args = self._get_srt_server_args() server_args["model_path"] = base_path server_args["lora_paths"] = adaptor_names server_args["enable_lora"] = True + self.logger.info(f"Server args: {server_args}") + self.logger.info("Creating SRTRunner...") + # Initialize SGLang runner with launch.json args + self.logger.info("Initializing SRTRunner...") srt_runner = SRTRunner( model_path=base_path, torch_dtype=torch_dtype, @@ -162,6 +205,9 @@ def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32) tp_size=server_args["tp_size"], lora_paths=adaptor_names, max_loras_per_batch=server_args["max_loras_per_batch"], + quantization=server_args["quantization"], + max_lora_chunk_size=server_args["max_lora_chunk_size"], + max_running_requests=server_args["max_running_requests"], lora_backend=server_args["lora_backend"], disable_cuda_graph=False, disable_radix_cache=server_args["disable_radix_cache"], @@ -169,15 +215,11 @@ def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32) page_size=server_args["page_size"], mem_fraction_static=server_args["mem_fraction_static"], sleep_on_idle=True, - # Pass additional args via json_model_override_args - json_model_override_args={ - "quantization": server_args["quantization"], - "max_lora_chunk_size": server_args["max_lora_chunk_size"], - "max_running_requests": server_args["max_running_requests"], - } ) + self.logger.info("SRTRunner created successfully") results = {} + self._log_system_info() # Check memory after runner creation # Test with different batch configurations test_configs = [ @@ -185,7 +227,11 @@ def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32) {"batch_size": 2, "lora_paths": 
[adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA ] + self.logger.info("Entering SRTRunner context manager (this loads the model)...") with srt_runner: + self.logger.info("SRTRunner context entered - model should be loaded") + self._log_system_info() # Check memory after model loading + for config in test_configs: batch_size = config["batch_size"] lora_paths = config["lora_paths"] @@ -195,18 +241,30 @@ def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32) config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - print(f"\n--- SRT Testing {config_key} ---") - print(f"Prompts: {prompts}") + self.logger.info(f"=== Testing config: {config_key} ===") + self.logger.info(f"Batch size: {batch_size}") + self.logger.info(f"LoRA paths: {lora_paths}") + self.logger.info(f"Prompts: {prompts}") # Ensure reproducibility + self.logger.info("Ensuring reproducibility...") ensure_reproducibility() # Run SGLang + self.logger.info("Running batch_forward...") srt_outputs = srt_runner.batch_forward( prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths, ) + self.logger.info("batch_forward completed successfully") + + # Print responses + self.logger.info("=== SRT Generated Responses ===") + for i, (prompt, output) in enumerate(zip(prompts, srt_outputs.output_strs)): + self.logger.info(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") + self.logger.info(f"SRT Output {i+1}: {output}") + self.logger.info("") results[config_key] = { "prompts": prompts, @@ -217,29 +275,56 @@ def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32) "top_output_logprobs": srt_outputs.top_output_logprobs, } } + self.logger.info(f"Results saved for config: {config_key}") # Save results + self.logger.info(f"Saving results to {result_file}") with open(result_file, 'wb') as f: pickle.dump(results, f) - print(f"SGLang results saved to {result_file}") + self.logger.info(f"SGLang results saved successfully to {result_file}") + + # Force GPU memory cleanup after context manager exit + self.logger.info("Forcing GPU memory cleanup...") + import torch + torch.cuda.empty_cache() + torch.cuda.synchronize() + self._log_system_info() # Log memory after cleanup except Exception as e: - print(f"SGLang test failed: {e}") + self.logger.error(f"SGLang test failed: {e}") + import traceback + self.logger.error(f"Full traceback: {traceback.format_exc()}") with open(result_file, 'wb') as f: - pickle.dump({"error": str(e)}, f) + pickle.dump({"error": str(e), "traceback": traceback.format_exc()}, f) def _run_hf_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): """Run HuggingFace test and save results to file.""" try: + self.logger.info("=== Starting HF MoE LoRA test ===") + self._log_system_info() + base_path = model_case.base adaptor_names = [adaptor.name for adaptor in model_case.adaptors] - print(f"\n========== Running HF MoE LoRA test on '{base_path}', dtype={torch_dtype} ==========") + self.logger.info(f"Model: {base_path}") + self.logger.info(f"LoRA adapters: {adaptor_names}") + self.logger.info(f"dtype: {torch_dtype}") + + # Get server args for consistency + server_args = self._get_srt_server_args() # Initialize HF runner + self.logger.info("Creating HFRunner...") hf_runner = HFRunner( - base_path, torch_dtype=torch_dtype, model_type="generation" + base_path, + 
torch_dtype=torch_dtype, + model_type="generation", + trust_remote_code=True, # Match SRTRunner behavior + quantization=server_args["quantization"], # Enable quantization if specified + device_map="auto", # Distribute across available GPUs (like tp_size for SRT) ) + self.logger.info("HFRunner created successfully") + self._log_system_info() # Check memory after runner creation results = {} @@ -249,7 +334,11 @@ def _run_hf_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA ] + self.logger.info("Entering HFRunner context manager (this loads the model)...") with hf_runner: + self.logger.info("HFRunner context entered - model should be loaded") + self._log_system_info() # Check memory after model loading + for config in test_configs: batch_size = config["batch_size"] lora_paths = config["lora_paths"] @@ -259,18 +348,30 @@ def _run_hf_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - print(f"\n--- HF Testing {config_key} ---") - print(f"Prompts: {prompts}") + self.logger.info(f"=== HF Testing config: {config_key} ===") + self.logger.info(f"Batch size: {batch_size}") + self.logger.info(f"LoRA paths: {lora_paths}") + self.logger.info(f"Prompts: {prompts}") # Ensure reproducibility + self.logger.info("Ensuring reproducibility...") ensure_reproducibility() # Run HuggingFace + self.logger.info("Running HF forward...") hf_outputs = hf_runner.forward( prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths, ) + self.logger.info("HF forward completed successfully") + + # Print responses + self.logger.info("=== HF Generated Responses ===") + for i, (prompt, output) in enumerate(zip(prompts, hf_outputs.output_strs)): + self.logger.info(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") + self.logger.info(f"HF Output {i+1}: {output}") + self.logger.info("") results[config_key] = { "prompts": prompts, @@ -281,59 +382,91 @@ def _run_hf_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): "top_output_logprobs": hf_outputs.top_output_logprobs, } } + self.logger.info(f"HF results saved for config: {config_key}") # Save results + self.logger.info(f"Saving HF results to {result_file}") with open(result_file, 'wb') as f: pickle.dump(results, f) - print(f"HF results saved to {result_file}") + self.logger.info(f"HF results saved successfully to {result_file}") + + # Force GPU memory cleanup after context manager exit + self.logger.info("Forcing GPU memory cleanup...") + import torch + torch.cuda.empty_cache() + torch.cuda.synchronize() + self._log_system_info() # Log memory after cleanup except Exception as e: - print(f"HF test failed: {e}") + self.logger.error(f"HF test failed: {e}") + import traceback + self.logger.error(f"Full traceback: {traceback.format_exc()}") with open(result_file, 'wb') as f: - pickle.dump({"error": str(e)}, f) + pickle.dump({"error": str(e), "traceback": traceback.format_exc()}, f) def _run_moe_lora_comparison(self, model_case: LoRAModelCase, torch_dtype, max_new_tokens=32): """Run LoRA comparison test by loading saved results from separate sessions.""" + self.logger.info("=== Starting MoE LoRA comparison test ===") + self._log_system_info() + base_path = model_case.base + self.logger.info(f"Model: {base_path}, dtype: {torch_dtype}") # Create temp directory for results with tempfile.TemporaryDirectory() as temp_dir: srt_result_file = os.path.join(temp_dir, 'srt_results.pkl') hf_result_file = os.path.join(temp_dir, 'hf_results.pkl') - print(f"\n========== Testing MoE LoRA on '{base_path}', dtype={torch_dtype} ==========") + self.logger.info(f"Results directory: {temp_dir}") + self.logger.info(f"SRT result file: {srt_result_file}") + self.logger.info(f"HF result file: 
{hf_result_file}") # Check if results already exist srt_done = os.path.exists(srt_result_file) hf_done = os.path.exists(hf_result_file) + self.logger.info(f"SRT results exist: {srt_done}") + self.logger.info(f"HF results exist: {hf_done}") + if not srt_done: - print("Running SGLang test...") + self.logger.info("Running SGLang test...") self._run_srt_test(srt_result_file, model_case, torch_dtype, max_new_tokens) srt_done = True if not hf_done: - print("Running HuggingFace test...") + self.logger.info("Running HuggingFace test...") self._run_hf_test(hf_result_file, model_case, torch_dtype, max_new_tokens) hf_done = True # Load results + self.logger.info("Loading SRT results...") if srt_done: with open(srt_result_file, 'rb') as f: srt_results = pickle.load(f) if "error" in srt_results: - self.fail(f"SGLang test failed: {srt_results['error']}") + error_msg = f"SGLang test failed: {srt_results['error']}" + if "traceback" in srt_results: + error_msg += f"\nTraceback: {srt_results['traceback']}" + self.fail(error_msg) + self.logger.info("SRT results loaded successfully") + self.logger.info("Loading HF results...") if hf_done: with open(hf_result_file, 'rb') as f: hf_results = pickle.load(f) if "error" in hf_results: - self.fail(f"HF test failed: {hf_results['error']}") + error_msg = f"HF test failed: {hf_results['error']}" + if "traceback" in hf_results: + error_msg += f"\nTraceback: {hf_results['traceback']}" + self.fail(error_msg) + self.logger.info("HF results loaded successfully") # Compare results if srt_done and hf_done: + self.logger.info("Starting result comparison...") for config_key in srt_results.keys(): if config_key in hf_results: + self.logger.info(f"Comparing config: {config_key}") srt_data = srt_results[config_key] hf_data = hf_results[config_key] @@ -341,6 +474,10 @@ def _run_moe_lora_comparison(self, model_case: LoRAModelCase, torch_dtype, max_n srt_data["srt_outputs"], hf_data["hf_outputs"], model_case, srt_data["prompts"], srt_data["lora_paths"] ) + 
else: + self.logger.warning(f"No HF results for config: {config_key}") + + self.logger.info("Comparison completed successfully") def _compare_outputs(self, srt_outputs, hf_outputs, model_case, prompts, lora_paths): """Compare SGLang and HF outputs.""" @@ -406,18 +543,26 @@ def _compare_outputs(self, srt_outputs, hf_outputs, model_case, prompts, lora_pa def test_moe_lora_qwen15(self): """Test LoRA on Qwen1.5-MoE-A2.7B.""" + self.logger.info("=== Starting test_moe_lora_qwen15 ===") + if is_in_ci(): + self.logger.info("Skipping MoE LoRA test in CI environment") self.skipTest("Skipping MoE LoRA test in CI environment") model_case = MOE_LORA_TEST_CASES[0] + self.logger.info(f"Using model case: {model_case.base}") # Test with different dtypes import torch for torch_dtype in [torch.float16, torch.bfloat16]: + self.logger.info(f"Testing dtype: {torch_dtype}") with self.subTest(dtype=torch_dtype): try: self._run_moe_lora_comparison(model_case, torch_dtype) except Exception as e: + self.logger.error(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") + import traceback + self.logger.error(f"Traceback: {traceback.format_exc()}") self.fail(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") def test_moe_lora_basic_functionality(self): @@ -469,10 +614,220 @@ def test_moe_lora_basic_functionality(self): print(f"Dispatch successful: {len(token_ids)} dispatched tokens to experts {unique_experts.tolist()}") -if __name__ == "__main__": +def debug_full_comparison(): + """Debug helper to run full SRT vs HF comparison.""" + import torch + import tempfile + import os + + # Set up logging for debugging + logging.basicConfig( + level=logging.INFO, # Less verbose for debug script + format='%(asctime)s - %(levelname)s - %(message)s' + ) + + test = TestMoELoRA() + model_case = MOE_LORA_TEST_CASES[0] + + print("=" * 80) + print("DEBUG: Running Full SRT vs HF Comparison") + print("=" * 80) + + srt_results = {} + hf_results = {} + + # Test configurations + 
test_configs = [ + {"batch_size": 1, "lora_paths": [model_case.adaptors[0].name]}, # Single request, single LoRA + {"batch_size": 2, "lora_paths": [model_case.adaptors[0].name, model_case.adaptors[0].name]}, # Multiple requests, same LoRA + ] + try: - mp.set_start_method("spawn") - except RuntimeError: - pass + # Phase 1: Run SRT tests + print("\n" + "="*50) + print("PHASE 1: Running SGLang (SRT) Tests") + print("="*50) + + server_args = test._get_srt_server_args() + print(f"Server args: quantization={server_args['quantization']}, tp_size={server_args['tp_size']}, mem_fraction_static={server_args['mem_fraction_static']}") + + srt_runner = SRTRunner( + model_path=model_case.base, + torch_dtype=torch.float16, + model_type="generation", + tp_size=server_args["tp_size"], + lora_paths=[adaptor.name for adaptor in model_case.adaptors], + max_loras_per_batch=server_args["max_loras_per_batch"], + quantization=server_args["quantization"], + max_lora_chunk_size=server_args["max_lora_chunk_size"], + max_running_requests=server_args["max_running_requests"], + lora_backend=server_args["lora_backend"], + disable_cuda_graph=False, + disable_radix_cache=server_args["disable_radix_cache"], + max_total_tokens=server_args["max_total_tokens"], + page_size=server_args["page_size"], + mem_fraction_static=server_args["mem_fraction_static"], + sleep_on_idle=True, + ) + print("✓ SRT runner created successfully") + + with srt_runner: + print("✓ SRT model loaded successfully") + test._log_system_info() + + for config in test_configs: + batch_size = config["batch_size"] + lora_paths = config["lora_paths"] + prompts = TEST_MOE_PROMPTS[:batch_size] + config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" + + print(f"\n--- SRT Testing {config_key} ---") + print(f"Prompts: {prompts}") + + # Ensure reproducibility + ensure_reproducibility() + + srt_outputs = srt_runner.batch_forward( + prompts, + max_new_tokens=32, + lora_paths=lora_paths, + ) + print("✓ SRT batch_forward completed") + + # 
Print SRT responses + print("=== SRT Generated Responses ===") + for i, (prompt, output) in enumerate(zip(prompts, srt_outputs.output_strs)): + print(f"Prompt {i+1}: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") + print(f"SRT Output {i+1}: {output}") + print("") + + srt_results[config_key] = { + "prompts": prompts, + "lora_paths": lora_paths, + "srt_outputs": srt_outputs + } + + print("✓ All SRT tests completed successfully") + + # Clear GPU memory before HF test + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("✓ GPU cache cleared before HF test") + + # Small delay to ensure memory is fully released + import time + time.sleep(2) + test._log_system_info() + + # Phase 2: Run HF tests + print("\n" + "="*50) + print("PHASE 2: Running HuggingFace (HF) Tests") + print("="*50) + + server_args = test._get_srt_server_args() + hf_runner = HFRunner( + model_case.base, + torch_dtype=torch.float16, + model_type="generation", + trust_remote_code=True, # Match SRTRunner behavior + quantization=server_args["quantization"], # Enable quantization if specified + device_map="auto", # Distribute across available GPUs + ) + print("✓ HF runner created successfully") + + with hf_runner: + print("✓ HF model loaded successfully") + test._log_system_info() + + for config in test_configs: + batch_size = config["batch_size"] + lora_paths = config["lora_paths"] + prompts = TEST_MOE_PROMPTS[:batch_size] + config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" + + print(f"\n--- HF Testing {config_key} ---") + print(f"Prompts: {prompts}") + + # Ensure reproducibility + ensure_reproducibility() + + hf_outputs = hf_runner.forward( + prompts, + max_new_tokens=32, + lora_paths=lora_paths, + ) + print("✓ HF forward completed") + + # Print HF responses + print("=== HF Generated Responses ===") + for i, (prompt, output) in enumerate(zip(prompts, hf_outputs.output_strs)): + print(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") + print(f"HF Output {i+1}: {output}") + print("") + + hf_results[config_key] = { + "prompts": prompts, + "lora_paths": lora_paths, + "hf_outputs": hf_outputs + } + + print("✓ All HF tests completed successfully") + + # Force final GPU memory cleanup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("✓ Final GPU memory cleanup completed") + test._log_system_info() + + # Phase 3: Compare results + print("\n" + "="*50) + print("PHASE 3: Comparing SRT vs HF Outputs") + print("="*50) + + for config_key in srt_results.keys(): + if config_key in hf_results: + srt_data = srt_results[config_key] + hf_data = hf_results[config_key] + + print(f"\n{'='*30} Comparing {config_key} {'='*30}") + try: + test._compare_outputs( + srt_data["srt_outputs"], hf_data["hf_outputs"], + model_case, srt_data["prompts"], srt_data["lora_paths"] + ) + print(f"✓ Comparison passed for {config_key}") + except AssertionError as e: + print(f"✗ Comparison failed for {config_key}: {e}") + raise + else: + print(f"✗ No HF results for config: {config_key}") + + print("\n" + "="*80) + print("🎉 ALL COMPARISONS PASSED! 
SRT and HF outputs match!") + print("="*80) + + except Exception as e: + print(f"\n💥 DEBUG SCRIPT FAILED: {e}") + import traceback + traceback.print_exc() + return False + + return True + + +if __name__ == "__main__": + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "--debug": + # Run debug tests instead of unittest + success = debug_full_comparison() + sys.exit(0 if success else 1) + else: + try: + mp.set_start_method("spawn") + except RuntimeError: + pass - unittest.main() + unittest.main() From 7fa7ddd1f8322d3670d817739f546cd6d85ff8ce Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 28 Nov 2025 16:03:34 +0000 Subject: [PATCH 028/150] simplify test --- test/srt/lora/test_lora_moe.py | 720 +++++---------------------------- 1 file changed, 112 insertions(+), 608 deletions(-) diff --git a/test/srt/lora/test_lora_moe.py b/test/srt/lora/test_lora_moe.py index 08bc1ec71637..e59cdb26717e 100644 --- a/test/srt/lora/test_lora_moe.py +++ b/test/srt/lora/test_lora_moe.py @@ -15,102 +15,32 @@ """ Test MoE LoRA implementation by comparing against HuggingFace. -This test file verifies that SGLang's MoE LoRA implementation produces the same -outputs as HuggingFace's PEFT library for MoE models. - -WORKFLOW: -1. Run SGLang test in first session (loads SGLang model, saves results to temp file) -2. Run HF test in second session (loads HF model, saves results to temp file) -3. Compare saved results from both sessions - -This avoids loading both models simultaneously, saving GPU memory. 
- Usage: - # Run basic functionality test (no HF comparison) - python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_basic_functionality -v - - # Run full HF comparison test (requires model and LoRA adapter) - # This will run SGLang and HF tests separately, then compare results - python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_qwen15 -v - -Manual Testing: - # Run full SRT vs HF comparison (sequential, saves memory) + python -m pytest test/srt/lora/test_lora_moe.py -v python test_lora_moe.py --debug - - # Or run individual components separately: - # Terminal 1: Run SGLang test and save results - python -c " - from test.srt.lora.test_lora_moe import TestMoELoRA - import torch - import tempfile - import os - - test = TestMoELoRA() - model_case = test.MOE_LORA_TEST_CASES[0] - - with tempfile.TemporaryDirectory() as temp_dir: - result_file = os.path.join(temp_dir, 'srt_results.pkl') - test._run_srt_test(result_file, model_case, torch.float16) - print(f'SGLang results saved to {result_file}') - input('Press Enter after copying result file...') - " - - # Terminal 2: Run HF test and save results - python -c " - from test.srt.lora.test_lora_moe import TestMoELoRA - import torch - import tempfile - import os - - test = TestMoELoRA() - model_case = test.MOE_LORA_TEST_CASES[0] - - with tempfile.TemporaryDirectory() as temp_dir: - result_file = os.path.join(temp_dir, 'hf_results.pkl') - test._run_hf_test(result_file, model_case, torch.float16) - print(f'HF results saved to {result_file}') - input('Press Enter after copying result file...') - " - -Requirements: - - Qwen/Qwen1.5-MoE-A2.7B model - - sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest LoRA adapter - - Sufficient GPU memory (tests run in separate sessions) - - Uses same memory configuration as launch.json for compatibility """ -import json import logging import multiprocessing as mp -import os import pickle -import psutil -import random import tempfile -import time 
+import os import torch import unittest -from pathlib import Path from utils import LoRAModelCase, LoRAAdaptor, ensure_reproducibility - from sglang.test.runners import HFRunner, SRTRunner from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci -# Test prompts for MoE LoRA comparison -TEST_MOE_PROMPTS = [ +TEST_PROMPTS = [ "The capital of France is Paris. The capital of", "Explain what mixture of experts means in machine learning.", - "Write a short poem about artificial intelligence and large language models.", - "What are the benefits of using MoE architectures in transformers?", ] -# MoE model test cases with LoRA adapters MOE_LORA_TEST_CASES = [ LoRAModelCase( base="Qwen/Qwen1.5-MoE-A2.7B", adaptors=[ - # Use a real LoRA adapter path - replace with actual path when testing LoRAAdaptor( name="sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest", prefill_tolerance=1e-1, @@ -124,7 +54,6 @@ rouge_l_tolerance=1.0, max_loras_per_batch=1, ), - # Add more MoE models here when available ] @@ -133,40 +62,16 @@ class TestMoELoRA(CustomTestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Set up detailed logging - logging.basicConfig( - level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') self.logger = logging.getLogger(__name__) - def _log_system_info(self): - """Log system and GPU memory information.""" - try: - # CPU memory - memory = psutil.virtual_memory() - self.logger.info(f"System Memory: {memory.available / (1024**3):.2f}GB available / {memory.total / (1024**3):.2f}GB total") - - # GPU memory if available - if torch.cuda.is_available(): - gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) - allocated = torch.cuda.memory_allocated(0) / (1024**3) - reserved = torch.cuda.memory_reserved(0) / (1024**3) - self.logger.info(f"GPU Memory: {gpu_memory:.2f}GB total, 
{allocated:.2f}GB allocated, {reserved:.2f}GB reserved") - else: - self.logger.warning("CUDA not available") - except Exception as e: - self.logger.warning(f"Could not get system info: {e}") - - def _get_srt_server_args(self): - """Get SGLang server arguments from launch.json config.""" + def _get_server_args(self): + """Get SGLang server arguments.""" return { "quantization": "fp8", "disable_radix_cache": True, "lora_backend": "csgmv", "max_lora_chunk_size": 16, - "port": 30000, - "host": "127.0.0.1", "max_loras_per_batch": 1, "tp_size": 2, "max_total_tokens": 128, @@ -176,30 +81,13 @@ def _get_srt_server_args(self): } def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): - """Run SGLang test and save results to file.""" + """Run SGLang test and save results.""" try: - self.logger.info("=== Starting SGLang MoE LoRA test ===") - self._log_system_info() - - base_path = model_case.base - adaptor_names = [adaptor.name for adaptor in model_case.adaptors] + server_args = self._get_server_args() + adaptor_names = [a.name for a in model_case.adaptors] - self.logger.info(f"Model: {base_path}") - self.logger.info(f"LoRA adapters: {adaptor_names}") - self.logger.info(f"dtype: {torch_dtype}") - - server_args = self._get_srt_server_args() - server_args["model_path"] = base_path - server_args["lora_paths"] = adaptor_names - server_args["enable_lora"] = True - - self.logger.info(f"Server args: {server_args}") - self.logger.info("Creating SRTRunner...") - - # Initialize SGLang runner with launch.json args - self.logger.info("Initializing SRTRunner...") srt_runner = SRTRunner( - model_path=base_path, + model_path=model_case.base, torch_dtype=torch_dtype, model_type="generation", tp_size=server_args["tp_size"], @@ -209,619 +97,236 @@ def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32) max_lora_chunk_size=server_args["max_lora_chunk_size"], max_running_requests=server_args["max_running_requests"], 
lora_backend=server_args["lora_backend"], - disable_cuda_graph=False, disable_radix_cache=server_args["disable_radix_cache"], max_total_tokens=server_args["max_total_tokens"], page_size=server_args["page_size"], mem_fraction_static=server_args["mem_fraction_static"], - sleep_on_idle=True, ) - self.logger.info("SRTRunner created successfully") results = {} - self._log_system_info() # Check memory after runner creation - - # Test with different batch configurations - test_configs = [ - {"batch_size": 1, "lora_paths": [adaptor_names[0]]}, # Single request, single LoRA - {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA - ] - - self.logger.info("Entering SRTRunner context manager (this loads the model)...") with srt_runner: - self.logger.info("SRTRunner context entered - model should be loaded") - self._log_system_info() # Check memory after model loading - - for config in test_configs: - batch_size = config["batch_size"] - lora_paths = config["lora_paths"] - - # Use fixed prompts for reproducibility - prompts = TEST_MOE_PROMPTS[:batch_size] - - config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - - self.logger.info(f"=== Testing config: {config_key} ===") - self.logger.info(f"Batch size: {batch_size}") - self.logger.info(f"LoRA paths: {lora_paths}") - self.logger.info(f"Prompts: {prompts}") - - # Ensure reproducibility - self.logger.info("Ensuring reproducibility...") + for batch_size in [1, 2]: + prompts = TEST_PROMPTS[:batch_size] + lora_paths = [adaptor_names[0]] * batch_size ensure_reproducibility() - # Run SGLang - self.logger.info("Running batch_forward...") - srt_outputs = srt_runner.batch_forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - self.logger.info("batch_forward completed successfully") - - # Print responses - self.logger.info("=== SRT Generated Responses ===") - for i, (prompt, output) in enumerate(zip(prompts, srt_outputs.output_strs)): - 
self.logger.info(f"Prompt {i+1}: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") - self.logger.info(f"SRT Output {i+1}: {output}") - self.logger.info("") - - results[config_key] = { + outputs = srt_runner.batch_forward(prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths) + results[f"batch_{batch_size}"] = { "prompts": prompts, "lora_paths": lora_paths, - "srt_outputs": { - "output_strs": srt_outputs.output_strs, - "top_input_logprobs": srt_outputs.top_input_logprobs, - "top_output_logprobs": srt_outputs.top_output_logprobs, + "outputs": { + "output_strs": outputs.output_strs, + "top_input_logprobs": outputs.top_input_logprobs, + "top_output_logprobs": outputs.top_output_logprobs, } } - self.logger.info(f"Results saved for config: {config_key}") - # Save results - self.logger.info(f"Saving results to {result_file}") with open(result_file, 'wb') as f: pickle.dump(results, f) - self.logger.info(f"SGLang results saved successfully to {result_file}") - # Force GPU memory cleanup after context manager exit - self.logger.info("Forcing GPU memory cleanup...") - import torch torch.cuda.empty_cache() - torch.cuda.synchronize() - self._log_system_info() # Log memory after cleanup except Exception as e: - self.logger.error(f"SGLang test failed: {e}") import traceback - self.logger.error(f"Full traceback: {traceback.format_exc()}") with open(result_file, 'wb') as f: pickle.dump({"error": str(e), "traceback": traceback.format_exc()}, f) def _run_hf_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): - """Run HuggingFace test and save results to file.""" + """Run HuggingFace test and save results.""" try: - self.logger.info("=== Starting HF MoE LoRA test ===") - self._log_system_info() - - base_path = model_case.base - adaptor_names = [adaptor.name for adaptor in model_case.adaptors] - - self.logger.info(f"Model: {base_path}") - self.logger.info(f"LoRA adapters: {adaptor_names}") - self.logger.info(f"dtype: {torch_dtype}") + server_args = 
self._get_server_args() + adaptor_names = [a.name for a in model_case.adaptors] - # Get server args for consistency - server_args = self._get_srt_server_args() - - # Initialize HF runner - self.logger.info("Creating HFRunner...") hf_runner = HFRunner( - base_path, + model_case.base, torch_dtype=torch_dtype, model_type="generation", - trust_remote_code=True, # Match SRTRunner behavior - quantization=server_args["quantization"], # Enable quantization if specified - device_map="auto", # Distribute across available GPUs (like tp_size for SRT) + trust_remote_code=True, + quantization=server_args["quantization"], + device_map="auto", ) - self.logger.info("HFRunner created successfully") - self._log_system_info() # Check memory after runner creation results = {} - - # Test with different batch configurations - test_configs = [ - {"batch_size": 1, "lora_paths": [adaptor_names[0]]}, # Single request, single LoRA - {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA - ] - - self.logger.info("Entering HFRunner context manager (this loads the model)...") with hf_runner: - self.logger.info("HFRunner context entered - model should be loaded") - self._log_system_info() # Check memory after model loading - - for config in test_configs: - batch_size = config["batch_size"] - lora_paths = config["lora_paths"] - - # Use fixed prompts for reproducibility - prompts = TEST_MOE_PROMPTS[:batch_size] - - config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - - self.logger.info(f"=== HF Testing config: {config_key} ===") - self.logger.info(f"Batch size: {batch_size}") - self.logger.info(f"LoRA paths: {lora_paths}") - self.logger.info(f"Prompts: {prompts}") - - # Ensure reproducibility - self.logger.info("Ensuring reproducibility...") + for batch_size in [1, 2]: + prompts = TEST_PROMPTS[:batch_size] + lora_paths = [adaptor_names[0]] * batch_size ensure_reproducibility() - # Run HuggingFace - self.logger.info("Running HF forward...") 
- hf_outputs = hf_runner.forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - self.logger.info("HF forward completed successfully") - - # Print responses - self.logger.info("=== HF Generated Responses ===") - for i, (prompt, output) in enumerate(zip(prompts, hf_outputs.output_strs)): - self.logger.info(f"Prompt {i+1}: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") - self.logger.info(f"HF Output {i+1}: {output}") - self.logger.info("") - - results[config_key] = { + outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths) + results[f"batch_{batch_size}"] = { "prompts": prompts, "lora_paths": lora_paths, - "hf_outputs": { - "output_strs": hf_outputs.output_strs, - "top_input_logprobs": hf_outputs.top_input_logprobs, - "top_output_logprobs": hf_outputs.top_output_logprobs, + "outputs": { + "output_strs": outputs.output_strs, + "top_input_logprobs": outputs.top_input_logprobs, + "top_output_logprobs": outputs.top_output_logprobs, } } - self.logger.info(f"HF results saved for config: {config_key}") - # Save results - self.logger.info(f"Saving HF results to {result_file}") with open(result_file, 'wb') as f: pickle.dump(results, f) - self.logger.info(f"HF results saved successfully to {result_file}") - # Force GPU memory cleanup after context manager exit - self.logger.info("Forcing GPU memory cleanup...") - import torch torch.cuda.empty_cache() - torch.cuda.synchronize() - self._log_system_info() # Log memory after cleanup except Exception as e: - self.logger.error(f"HF test failed: {e}") import traceback - self.logger.error(f"Full traceback: {traceback.format_exc()}") with open(result_file, 'wb') as f: pickle.dump({"error": str(e), "traceback": traceback.format_exc()}, f) - def _run_moe_lora_comparison(self, model_case: LoRAModelCase, torch_dtype, max_new_tokens=32): - """Run LoRA comparison test by loading saved results from separate sessions.""" - self.logger.info("=== Starting MoE LoRA comparison 
test ===") - self._log_system_info() - - base_path = model_case.base - self.logger.info(f"Model: {base_path}, dtype: {torch_dtype}") - - # Create temp directory for results - with tempfile.TemporaryDirectory() as temp_dir: - srt_result_file = os.path.join(temp_dir, 'srt_results.pkl') - hf_result_file = os.path.join(temp_dir, 'hf_results.pkl') - - self.logger.info(f"Results directory: {temp_dir}") - self.logger.info(f"SRT result file: {srt_result_file}") - self.logger.info(f"HF result file: {hf_result_file}") - - # Check if results already exist - srt_done = os.path.exists(srt_result_file) - hf_done = os.path.exists(hf_result_file) - - self.logger.info(f"SRT results exist: {srt_done}") - self.logger.info(f"HF results exist: {hf_done}") - - if not srt_done: - self.logger.info("Running SGLang test...") - self._run_srt_test(srt_result_file, model_case, torch_dtype, max_new_tokens) - srt_done = True - - if not hf_done: - self.logger.info("Running HuggingFace test...") - self._run_hf_test(hf_result_file, model_case, torch_dtype, max_new_tokens) - hf_done = True - - # Load results - self.logger.info("Loading SRT results...") - if srt_done: - with open(srt_result_file, 'rb') as f: - srt_results = pickle.load(f) - if "error" in srt_results: - error_msg = f"SGLang test failed: {srt_results['error']}" - if "traceback" in srt_results: - error_msg += f"\nTraceback: {srt_results['traceback']}" - self.fail(error_msg) - self.logger.info("SRT results loaded successfully") - - self.logger.info("Loading HF results...") - if hf_done: - with open(hf_result_file, 'rb') as f: - hf_results = pickle.load(f) - if "error" in hf_results: - error_msg = f"HF test failed: {hf_results['error']}" - if "traceback" in hf_results: - error_msg += f"\nTraceback: {hf_results['traceback']}" - self.fail(error_msg) - self.logger.info("HF results loaded successfully") - - # Compare results - if srt_done and hf_done: - self.logger.info("Starting result comparison...") - for config_key in srt_results.keys(): 
- if config_key in hf_results: - self.logger.info(f"Comparing config: {config_key}") - srt_data = srt_results[config_key] - hf_data = hf_results[config_key] - - self._compare_outputs( - srt_data["srt_outputs"], hf_data["hf_outputs"], - model_case, srt_data["prompts"], srt_data["lora_paths"] - ) - else: - self.logger.warning(f"No HF results for config: {config_key}") - - self.logger.info("Comparison completed successfully") - def _compare_outputs(self, srt_outputs, hf_outputs, model_case, prompts, lora_paths): """Compare SGLang and HF outputs.""" for i, (prompt, lora_path) in enumerate(zip(prompts, lora_paths)): - print(f"\nRequest {i}: lora_path='{lora_path}'") - print(f"Prompt: {prompt[:50]}{'...' if len(prompt) > 50 else ''}") - - # Compare output strings - srt_output_str = srt_outputs.output_strs[i].strip() - hf_output_str = hf_outputs.output_strs[i].strip() - - print(f"SRT output: {srt_output_str[:100]}{'...' if len(srt_output_str) > 100 else ''}") - print(f"HF output: {hf_output_str[:100]}{'...' 
if len(hf_output_str) > 100 else ''}") - - # Calculate ROUGE-L similarity - rouge_l = calculate_rouge_l(srt_output_str, hf_output_str) - print(f"ROUGE-L similarity: {rouge_l:.4f}") - - # Check ROUGE-L tolerance - self.assertGreaterEqual( - rouge_l, - model_case.rouge_l_tolerance, - f"ROUGE-L similarity {rouge_l:.4f} below tolerance {model_case.rouge_l_tolerance} " - f"for request {i} with lora_path='{lora_path}'" - ) + srt_str = srt_outputs["output_strs"][i].strip() + hf_str = hf_outputs["output_strs"][i].strip() - # Compare logprobs if available - if hasattr(srt_outputs, 'top_input_logprobs') and hasattr(hf_outputs, 'top_input_logprobs'): - if srt_outputs.top_input_logprobs[i] is not None and hf_outputs.top_input_logprobs[i] is not None: - import torch - srt_prefill = torch.tensor(srt_outputs.top_input_logprobs[i]) - hf_prefill = torch.tensor(hf_outputs.top_input_logprobs[i]) - - max_prefill_diff = torch.max(torch.abs(hf_prefill - srt_prefill)) - print(f"Max prefill logprob diff: {max_prefill_diff:.6f}") - - # Check prefill tolerance - prefill_tol = model_case.prefill_tolerance - self.assertLessEqual( - max_prefill_diff, - prefill_tol, - f"Prefill logprob diff {max_prefill_diff:.6f} exceeds tolerance {prefill_tol} " - f"for request {i} with lora_path='{lora_path}'" - ) - - if hasattr(srt_outputs, 'top_output_logprobs') and hasattr(hf_outputs, 'top_output_logprobs'): - if srt_outputs.top_output_logprobs[i] is not None and hf_outputs.top_output_logprobs[i] is not None: - import torch - srt_decode = torch.tensor(srt_outputs.top_output_logprobs[i]) - hf_decode = torch.tensor(hf_outputs.top_output_logprobs[i]) - - max_decode_diff = torch.max(torch.abs(hf_decode - srt_decode)) - print(f"Max decode logprob diff: {max_decode_diff:.6f}") - - # Check decode tolerance - decode_tol = model_case.decode_tolerance - self.assertLessEqual( - max_decode_diff, - decode_tol, - f"Decode logprob diff {max_decode_diff:.6f} exceeds tolerance {decode_tol} " - f"for request {i} with 
lora_path='{lora_path}'" - ) + rouge_l = calculate_rouge_l(srt_str, hf_str) + self.assertGreaterEqual(rouge_l, model_case.rouge_l_tolerance, + f"ROUGE-L {rouge_l:.4f} below tolerance for request {i}") def test_moe_lora_qwen15(self): """Test LoRA on Qwen1.5-MoE-A2.7B.""" - self.logger.info("=== Starting test_moe_lora_qwen15 ===") - if is_in_ci(): - self.logger.info("Skipping MoE LoRA test in CI environment") - self.skipTest("Skipping MoE LoRA test in CI environment") + self.skipTest("Skipping MoE LoRA test in CI") model_case = MOE_LORA_TEST_CASES[0] - self.logger.info(f"Using model case: {model_case.base}") - - # Test with different dtypes - import torch - for torch_dtype in [torch.float16, torch.bfloat16]: - self.logger.info(f"Testing dtype: {torch_dtype}") - with self.subTest(dtype=torch_dtype): - try: - self._run_moe_lora_comparison(model_case, torch_dtype) - except Exception as e: - self.logger.error(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") - import traceback - self.logger.error(f"Traceback: {traceback.format_exc()}") - self.fail(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") + + with tempfile.TemporaryDirectory() as temp_dir: + srt_file = os.path.join(temp_dir, 'srt.pkl') + hf_file = os.path.join(temp_dir, 'hf.pkl') + + self._run_srt_test(srt_file, model_case, torch.float16) + self._run_hf_test(hf_file, model_case, torch.float16) + + with open(srt_file, 'rb') as f: + srt_results = pickle.load(f) + with open(hf_file, 'rb') as f: + hf_results = pickle.load(f) + + if "error" in srt_results: + self.fail(f"SRT failed: {srt_results['error']}") + if "error" in hf_results: + self.fail(f"HF failed: {hf_results['error']}") + + for key in srt_results: + if key in hf_results: + self._compare_outputs( + srt_results[key]["outputs"], + hf_results[key]["outputs"], + model_case, + srt_results[key]["prompts"], + srt_results[key]["lora_paths"] + ) def test_moe_lora_basic_functionality(self): """Basic functionality test for MoE LoRA 
dispatch.""" - # This test focuses on the core dispatch logic without full HF comparison - # Useful for debugging the MoE LoRA implementation - - import torch from sglang.srt.lora.moe_dispatch import moe_dispatch - # Create test data - num_tokens = 4 - top_k = 2 - num_experts = 8 - - # Mock top-k routing results - topk_ids = torch.tensor([ - [0, 1], # token 0 routes to experts 0, 1 - [2, 3], # token 1 routes to experts 2, 3 - [1, 4], # token 2 routes to experts 1, 4 - [5, 6], # token 3 routes to experts 5, 6 - ], dtype=torch.int32) - + topk_ids = torch.tensor([[0, 1], [2, 3], [1, 4], [5, 6]], dtype=torch.int32) topk_weights = torch.ones_like(topk_ids, dtype=torch.float32) + lora_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32) - # Mock LoRA indices (one per token) - lora_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32) # tokens 0,1 use lora 0; tokens 2,3 use lora 1 - - # Run dispatch token_ids, expert_ids, weights = moe_dispatch( topk_ids=topk_ids, topk_weights=topk_weights, lora_indices=lora_indices, - num_experts=num_experts, + num_experts=8, num_loras=2, ) - # Verify results - # Should have 4 tokens * 2 experts each = 8 dispatched entries self.assertEqual(len(token_ids), 8) self.assertEqual(len(expert_ids), 8) self.assertEqual(len(weights), 8) - # Check that tokens are grouped by expert (not by LoRA) - # All tokens going to expert 0 should come first, then expert 1, etc. 
- unique_experts, expert_counts = torch.unique_consecutive(expert_ids, return_counts=True) - self.assertTrue(torch.all(expert_counts >= 1)) # Each expert should have at least one token - - print(f"Dispatch successful: {len(token_ids)} dispatched tokens to experts {unique_experts.tolist()}") - def debug_full_comparison(): """Debug helper to run full SRT vs HF comparison.""" - import torch - import tempfile - import os - - # Set up logging for debugging - logging.basicConfig( - level=logging.INFO, # Less verbose for debug script - format='%(asctime)s - %(levelname)s - %(message)s' - ) - test = TestMoELoRA() model_case = MOE_LORA_TEST_CASES[0] + server_args = test._get_server_args() + adaptor_names = [a.name for a in model_case.adaptors] - print("=" * 80) - print("DEBUG: Running Full SRT vs HF Comparison") - print("=" * 80) - - srt_results = {} - hf_results = {} - - # Test configurations - test_configs = [ - {"batch_size": 1, "lora_paths": [model_case.adaptors[0].name]}, # Single request, single LoRA - {"batch_size": 2, "lora_paths": [model_case.adaptors[0].name, model_case.adaptors[0].name]}, # Multiple requests, same LoRA - ] + srt_results, hf_results = {}, {} try: - # Phase 1: Run SRT tests - print("\n" + "="*50) - print("PHASE 1: Running SGLang (SRT) Tests") - print("="*50) - - server_args = test._get_srt_server_args() - print(f"Server args: quantization={server_args['quantization']}, tp_size={server_args['tp_size']}, mem_fraction_static={server_args['mem_fraction_static']}") - + # SRT tests + print("Running SGLang tests...") srt_runner = SRTRunner( model_path=model_case.base, torch_dtype=torch.float16, model_type="generation", tp_size=server_args["tp_size"], - lora_paths=[adaptor.name for adaptor in model_case.adaptors], + lora_paths=adaptor_names, max_loras_per_batch=server_args["max_loras_per_batch"], quantization=server_args["quantization"], max_lora_chunk_size=server_args["max_lora_chunk_size"], max_running_requests=server_args["max_running_requests"], 
lora_backend=server_args["lora_backend"], - disable_cuda_graph=False, disable_radix_cache=server_args["disable_radix_cache"], max_total_tokens=server_args["max_total_tokens"], page_size=server_args["page_size"], mem_fraction_static=server_args["mem_fraction_static"], - sleep_on_idle=True, ) - print("✓ SRT runner created successfully") with srt_runner: - print("✓ SRT model loaded successfully") - test._log_system_info() - - for config in test_configs: - batch_size = config["batch_size"] - lora_paths = config["lora_paths"] - prompts = TEST_MOE_PROMPTS[:batch_size] - config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - - print(f"\n--- SRT Testing {config_key} ---") - print(f"Prompts: {prompts}") - - # Ensure reproducibility + for batch_size in [1, 2]: + prompts = TEST_PROMPTS[:batch_size] + lora_paths = [adaptor_names[0]] * batch_size ensure_reproducibility() + outputs = srt_runner.batch_forward(prompts, max_new_tokens=32, lora_paths=lora_paths) + srt_results[f"batch_{batch_size}"] = {"prompts": prompts, "lora_paths": lora_paths, "outputs": outputs} + for i, out in enumerate(outputs.output_strs): + print(f"SRT [{batch_size}] {i}: {out}") - srt_outputs = srt_runner.batch_forward( - prompts, - max_new_tokens=32, - lora_paths=lora_paths, - ) - print("✓ SRT batch_forward completed") - - # Print SRT responses - print("=== SRT Generated Responses ===") - for i, (prompt, output) in enumerate(zip(prompts, srt_outputs.output_strs)): - print(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") - print(f"SRT Output {i+1}: {output}") - print("") - - srt_results[config_key] = { - "prompts": prompts, - "lora_paths": lora_paths, - "srt_outputs": srt_outputs - } - - print("✓ All SRT tests completed successfully") - - # Clear GPU memory before HF test - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - print("✓ GPU cache cleared before HF test") - - # Small delay to ensure memory is fully released - import time - time.sleep(2) - test._log_system_info() - - # Phase 2: Run HF tests - print("\n" + "="*50) - print("PHASE 2: Running HuggingFace (HF) Tests") - print("="*50) + torch.cuda.empty_cache() - server_args = test._get_srt_server_args() + # HF tests + print("\nRunning HuggingFace tests...") hf_runner = HFRunner( model_case.base, torch_dtype=torch.float16, model_type="generation", - trust_remote_code=True, # Match SRTRunner behavior - quantization=server_args["quantization"], # Enable quantization if specified - device_map="auto", # Distribute across available GPUs + trust_remote_code=True, + quantization=server_args["quantization"], + device_map="auto", ) - print("✓ HF runner created successfully") with hf_runner: - print("✓ HF model loaded successfully") - test._log_system_info() - - for config in test_configs: - batch_size = config["batch_size"] - lora_paths = config["lora_paths"] - prompts = TEST_MOE_PROMPTS[:batch_size] - config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - - print(f"\n--- HF Testing {config_key} ---") - print(f"Prompts: {prompts}") - - # Ensure reproducibility + for batch_size in [1, 2]: + prompts = TEST_PROMPTS[:batch_size] + lora_paths = [adaptor_names[0]] * batch_size ensure_reproducibility() - - hf_outputs = hf_runner.forward( - prompts, - max_new_tokens=32, - lora_paths=lora_paths, + outputs = hf_runner.forward(prompts, max_new_tokens=32, lora_paths=lora_paths) + hf_results[f"batch_{batch_size}"] = {"prompts": prompts, "lora_paths": lora_paths, "outputs": 
outputs} + for i, out in enumerate(outputs.output_strs): + print(f"HF [{batch_size}] {i}: {out}") + + torch.cuda.empty_cache() + + # Compare + print("\nComparing outputs...") + for key in srt_results: + srt_data, hf_data = srt_results[key], hf_results[key] + for i in range(len(srt_data["prompts"])): + rouge_l = calculate_rouge_l( + srt_data["outputs"].output_strs[i], + hf_data["outputs"].output_strs[i] ) - print("✓ HF forward completed") - - # Print HF responses - print("=== HF Generated Responses ===") - for i, (prompt, output) in enumerate(zip(prompts, hf_outputs.output_strs)): - print(f"Prompt {i+1}: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") - print(f"HF Output {i+1}: {output}") - print("") + print(f"{key} request {i}: ROUGE-L = {rouge_l:.4f}") - hf_results[config_key] = { - "prompts": prompts, - "lora_paths": lora_paths, - "hf_outputs": hf_outputs - } - - print("✓ All HF tests completed successfully") - - # Force final GPU memory cleanup - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - print("✓ Final GPU memory cleanup completed") - test._log_system_info() - - # Phase 3: Compare results - print("\n" + "="*50) - print("PHASE 3: Comparing SRT vs HF Outputs") - print("="*50) - - for config_key in srt_results.keys(): - if config_key in hf_results: - srt_data = srt_results[config_key] - hf_data = hf_results[config_key] - - print(f"\n{'='*30} Comparing {config_key} {'='*30}") - try: - test._compare_outputs( - srt_data["srt_outputs"], hf_data["hf_outputs"], - model_case, srt_data["prompts"], srt_data["lora_paths"] - ) - print(f"✓ Comparison passed for {config_key}") - except AssertionError as e: - print(f"✗ Comparison failed for {config_key}: {e}") - raise - else: - print(f"✗ No HF results for config: {config_key}") - - print("\n" + "="*80) - print("🎉 ALL COMPARISONS PASSED! 
SRT and HF outputs match!") - print("="*80) + print("\nAll comparisons completed!") + return True except Exception as e: - print(f"\n💥 DEBUG SCRIPT FAILED: {e}") import traceback + print(f"Failed: {e}") traceback.print_exc() return False - return True - if __name__ == "__main__": import sys - if len(sys.argv) > 1 and sys.argv[1] == "--debug": - # Run debug tests instead of unittest success = debug_full_comparison() sys.exit(0 if success else 1) else: @@ -829,5 +334,4 @@ def debug_full_comparison(): mp.set_start_method("spawn") except RuntimeError: pass - unittest.main() From 367612aef1c722f5b5a51e106b827f8becc9106c Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 29 Nov 2025 00:55:42 +0000 Subject: [PATCH 029/150] fix lora id issue --- python/sglang/srt/layers/moe/lora_moe.py | 5 +- python/sglang/srt/lora/lora_manager.py | 2 +- python/sglang/srt/lora/moe_dispatch.py | 15 +- test/srt/lora/test_lora_moe.py | 720 +++++++++++++++++++---- 4 files changed, 619 insertions(+), 123 deletions(-) diff --git a/python/sglang/srt/layers/moe/lora_moe.py b/python/sglang/srt/layers/moe/lora_moe.py index f4bf7886c7b9..2ea43c6380c6 100644 --- a/python/sglang/srt/layers/moe/lora_moe.py +++ b/python/sglang/srt/layers/moe/lora_moe.py @@ -104,7 +104,7 @@ def _compute_lora_delta( num_loras = self.lora_a_weights.shape[0] # Dispatch tokens to experts - token_ids, expert_ids, _ = moe_dispatch( + token_ids, expert_ids, _, lora_ids = moe_dispatch( topk_ids=topk_ids, topk_weights=topk_weights, lora_indices=lora_indices, @@ -112,9 +112,6 @@ def _compute_lora_delta( num_loras=num_loras, ) - # Get LoRA IDs for dispatched tokens - lora_ids = lora_indices[token_ids] - # Compute per-expert LoRA forward (adds to base_output in-place) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index c4818d0a6241..2a8384459616 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -285,7 +285,7 @@ def 
prepare_lora_batch(self, forward_batch: ForwardBatch): # Populate per-token LoRA indices from segment information batch_info = self.lora_backend.batch_info - num_tokens = forward_batch.batch_size + num_tokens = forward_batch.seq_lens_sum # Total tokens across all sequences if batch_info.permutation is None: # No reordering (e.g., triton backend): segments are in original order token_lora_indices = torch.empty(num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device) diff --git a/python/sglang/srt/lora/moe_dispatch.py b/python/sglang/srt/lora/moe_dispatch.py index 6a924b20b2f6..a34bf181af65 100644 --- a/python/sglang/srt/lora/moe_dispatch.py +++ b/python/sglang/srt/lora/moe_dispatch.py @@ -23,7 +23,7 @@ def moe_dispatch( lora_indices: torch.Tensor, num_experts: int, num_loras: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Dispatch tokens to experts for MoE computation. @@ -38,6 +38,7 @@ def moe_dispatch( sorted_token_ids: Token indices sorted by expert_id sorted_expert_ids: Corresponding expert IDs sorted_weights: Corresponding router weights + sorted_lora_ids: LoRA adapter IDs for each dispatched token """ num_tokens, top_k = topk_ids.shape device = topk_ids.device @@ -46,15 +47,17 @@ def moe_dispatch( flat_topk_ids = topk_ids.flatten() flat_topk_weights = topk_weights.flatten() flat_token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k) + flat_lora_ids = lora_indices.repeat_interleave(top_k) # Sort by expert_id only (each expert uses same LoRA adapter logic) - composite_key = flat_topk_ids - - # Sort by expert_id to group tokens by expert - sorted_indices = torch.argsort(composite_key) + sorted_indices = torch.argsort(flat_topk_ids) sorted_token_ids = flat_token_ids[sorted_indices] sorted_expert_ids = flat_topk_ids[sorted_indices] sorted_weights = flat_topk_weights[sorted_indices] - return sorted_token_ids, sorted_expert_ids, sorted_weights + 
if flat_lora_ids.shape != sorted_indices.shape: + y = 1 # need to pause + sorted_lora_ids = flat_lora_ids[sorted_indices] + + return sorted_token_ids, sorted_expert_ids, sorted_weights, sorted_lora_ids diff --git a/test/srt/lora/test_lora_moe.py b/test/srt/lora/test_lora_moe.py index e59cdb26717e..08bc1ec71637 100644 --- a/test/srt/lora/test_lora_moe.py +++ b/test/srt/lora/test_lora_moe.py @@ -15,32 +15,102 @@ """ Test MoE LoRA implementation by comparing against HuggingFace. +This test file verifies that SGLang's MoE LoRA implementation produces the same +outputs as HuggingFace's PEFT library for MoE models. + +WORKFLOW: +1. Run SGLang test in first session (loads SGLang model, saves results to temp file) +2. Run HF test in second session (loads HF model, saves results to temp file) +3. Compare saved results from both sessions + +This avoids loading both models simultaneously, saving GPU memory. + Usage: - python -m pytest test/srt/lora/test_lora_moe.py -v + # Run basic functionality test (no HF comparison) + python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_basic_functionality -v + + # Run full HF comparison test (requires model and LoRA adapter) + # This will run SGLang and HF tests separately, then compare results + python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_qwen15 -v + +Manual Testing: + # Run full SRT vs HF comparison (sequential, saves memory) python test_lora_moe.py --debug + + # Or run individual components separately: + # Terminal 1: Run SGLang test and save results + python -c " + from test.srt.lora.test_lora_moe import TestMoELoRA + import torch + import tempfile + import os + + test = TestMoELoRA() + model_case = test.MOE_LORA_TEST_CASES[0] + + with tempfile.TemporaryDirectory() as temp_dir: + result_file = os.path.join(temp_dir, 'srt_results.pkl') + test._run_srt_test(result_file, model_case, torch.float16) + print(f'SGLang results saved to {result_file}') + input('Press Enter after copying 
result file...') + " + + # Terminal 2: Run HF test and save results + python -c " + from test.srt.lora.test_lora_moe import TestMoELoRA + import torch + import tempfile + import os + + test = TestMoELoRA() + model_case = test.MOE_LORA_TEST_CASES[0] + + with tempfile.TemporaryDirectory() as temp_dir: + result_file = os.path.join(temp_dir, 'hf_results.pkl') + test._run_hf_test(result_file, model_case, torch.float16) + print(f'HF results saved to {result_file}') + input('Press Enter after copying result file...') + " + +Requirements: + - Qwen/Qwen1.5-MoE-A2.7B model + - sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest LoRA adapter + - Sufficient GPU memory (tests run in separate sessions) + - Uses same memory configuration as launch.json for compatibility """ +import json import logging import multiprocessing as mp +import os import pickle +import psutil +import random import tempfile -import os +import time import torch import unittest +from pathlib import Path from utils import LoRAModelCase, LoRAAdaptor, ensure_reproducibility + from sglang.test.runners import HFRunner, SRTRunner from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci -TEST_PROMPTS = [ +# Test prompts for MoE LoRA comparison +TEST_MOE_PROMPTS = [ "The capital of France is Paris. 
The capital of", "Explain what mixture of experts means in machine learning.", + "Write a short poem about artificial intelligence and large language models.", + "What are the benefits of using MoE architectures in transformers?", ] +# MoE model test cases with LoRA adapters MOE_LORA_TEST_CASES = [ LoRAModelCase( base="Qwen/Qwen1.5-MoE-A2.7B", adaptors=[ + # Use a real LoRA adapter path - replace with actual path when testing LoRAAdaptor( name="sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest", prefill_tolerance=1e-1, @@ -54,6 +124,7 @@ rouge_l_tolerance=1.0, max_loras_per_batch=1, ), + # Add more MoE models here when available ] @@ -62,16 +133,40 @@ class TestMoELoRA(CustomTestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + # Set up detailed logging + logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) self.logger = logging.getLogger(__name__) - def _get_server_args(self): - """Get SGLang server arguments.""" + def _log_system_info(self): + """Log system and GPU memory information.""" + try: + # CPU memory + memory = psutil.virtual_memory() + self.logger.info(f"System Memory: {memory.available / (1024**3):.2f}GB available / {memory.total / (1024**3):.2f}GB total") + + # GPU memory if available + if torch.cuda.is_available(): + gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) + allocated = torch.cuda.memory_allocated(0) / (1024**3) + reserved = torch.cuda.memory_reserved(0) / (1024**3) + self.logger.info(f"GPU Memory: {gpu_memory:.2f}GB total, {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") + else: + self.logger.warning("CUDA not available") + except Exception as e: + self.logger.warning(f"Could not get system info: {e}") + + def _get_srt_server_args(self): + """Get SGLang server arguments from launch.json config.""" return { "quantization": "fp8", 
"disable_radix_cache": True, "lora_backend": "csgmv", "max_lora_chunk_size": 16, + "port": 30000, + "host": "127.0.0.1", "max_loras_per_batch": 1, "tp_size": 2, "max_total_tokens": 128, @@ -81,13 +176,30 @@ def _get_server_args(self): } def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): - """Run SGLang test and save results.""" + """Run SGLang test and save results to file.""" try: - server_args = self._get_server_args() - adaptor_names = [a.name for a in model_case.adaptors] + self.logger.info("=== Starting SGLang MoE LoRA test ===") + self._log_system_info() + + base_path = model_case.base + adaptor_names = [adaptor.name for adaptor in model_case.adaptors] + self.logger.info(f"Model: {base_path}") + self.logger.info(f"LoRA adapters: {adaptor_names}") + self.logger.info(f"dtype: {torch_dtype}") + + server_args = self._get_srt_server_args() + server_args["model_path"] = base_path + server_args["lora_paths"] = adaptor_names + server_args["enable_lora"] = True + + self.logger.info(f"Server args: {server_args}") + self.logger.info("Creating SRTRunner...") + + # Initialize SGLang runner with launch.json args + self.logger.info("Initializing SRTRunner...") srt_runner = SRTRunner( - model_path=model_case.base, + model_path=base_path, torch_dtype=torch_dtype, model_type="generation", tp_size=server_args["tp_size"], @@ -97,236 +209,619 @@ def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32) max_lora_chunk_size=server_args["max_lora_chunk_size"], max_running_requests=server_args["max_running_requests"], lora_backend=server_args["lora_backend"], + disable_cuda_graph=False, disable_radix_cache=server_args["disable_radix_cache"], max_total_tokens=server_args["max_total_tokens"], page_size=server_args["page_size"], mem_fraction_static=server_args["mem_fraction_static"], + sleep_on_idle=True, ) + self.logger.info("SRTRunner created successfully") results = {} + self._log_system_info() # Check memory after runner creation 
+ + # Test with different batch configurations + test_configs = [ + {"batch_size": 1, "lora_paths": [adaptor_names[0]]}, # Single request, single LoRA + {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA + ] + + self.logger.info("Entering SRTRunner context manager (this loads the model)...") with srt_runner: - for batch_size in [1, 2]: - prompts = TEST_PROMPTS[:batch_size] - lora_paths = [adaptor_names[0]] * batch_size + self.logger.info("SRTRunner context entered - model should be loaded") + self._log_system_info() # Check memory after model loading + + for config in test_configs: + batch_size = config["batch_size"] + lora_paths = config["lora_paths"] + + # Use fixed prompts for reproducibility + prompts = TEST_MOE_PROMPTS[:batch_size] + + config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" + + self.logger.info(f"=== Testing config: {config_key} ===") + self.logger.info(f"Batch size: {batch_size}") + self.logger.info(f"LoRA paths: {lora_paths}") + self.logger.info(f"Prompts: {prompts}") + + # Ensure reproducibility + self.logger.info("Ensuring reproducibility...") ensure_reproducibility() - outputs = srt_runner.batch_forward(prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths) - results[f"batch_{batch_size}"] = { + # Run SGLang + self.logger.info("Running batch_forward...") + srt_outputs = srt_runner.batch_forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + self.logger.info("batch_forward completed successfully") + + # Print responses + self.logger.info("=== SRT Generated Responses ===") + for i, (prompt, output) in enumerate(zip(prompts, srt_outputs.output_strs)): + self.logger.info(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") + self.logger.info(f"SRT Output {i+1}: {output}") + self.logger.info("") + + results[config_key] = { "prompts": prompts, "lora_paths": lora_paths, - "outputs": { - "output_strs": outputs.output_strs, - "top_input_logprobs": outputs.top_input_logprobs, - "top_output_logprobs": outputs.top_output_logprobs, + "srt_outputs": { + "output_strs": srt_outputs.output_strs, + "top_input_logprobs": srt_outputs.top_input_logprobs, + "top_output_logprobs": srt_outputs.top_output_logprobs, } } + self.logger.info(f"Results saved for config: {config_key}") + # Save results + self.logger.info(f"Saving results to {result_file}") with open(result_file, 'wb') as f: pickle.dump(results, f) + self.logger.info(f"SGLang results saved successfully to {result_file}") + # Force GPU memory cleanup after context manager exit + self.logger.info("Forcing GPU memory cleanup...") + import torch torch.cuda.empty_cache() + torch.cuda.synchronize() + self._log_system_info() # Log memory after cleanup except Exception as e: + self.logger.error(f"SGLang test failed: {e}") import traceback + self.logger.error(f"Full traceback: {traceback.format_exc()}") with open(result_file, 'wb') as f: pickle.dump({"error": str(e), "traceback": traceback.format_exc()}, f) def _run_hf_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): - """Run HuggingFace test and save results.""" + """Run HuggingFace test and save results to file.""" try: - server_args = self._get_server_args() - adaptor_names = [a.name for a in model_case.adaptors] + self.logger.info("=== Starting HF MoE LoRA test ===") + self._log_system_info() + + base_path = model_case.base + adaptor_names = [adaptor.name for adaptor in model_case.adaptors] + + self.logger.info(f"Model: {base_path}") + self.logger.info(f"LoRA adapters: {adaptor_names}") + self.logger.info(f"dtype: {torch_dtype}") + # Get server args for consistency + server_args = self._get_srt_server_args() + + # Initialize HF runner + 
self.logger.info("Creating HFRunner...") hf_runner = HFRunner( - model_case.base, + base_path, torch_dtype=torch_dtype, model_type="generation", - trust_remote_code=True, - quantization=server_args["quantization"], - device_map="auto", + trust_remote_code=True, # Match SRTRunner behavior + quantization=server_args["quantization"], # Enable quantization if specified + device_map="auto", # Distribute across available GPUs (like tp_size for SRT) ) + self.logger.info("HFRunner created successfully") + self._log_system_info() # Check memory after runner creation results = {} + + # Test with different batch configurations + test_configs = [ + {"batch_size": 1, "lora_paths": [adaptor_names[0]]}, # Single request, single LoRA + {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA + ] + + self.logger.info("Entering HFRunner context manager (this loads the model)...") with hf_runner: - for batch_size in [1, 2]: - prompts = TEST_PROMPTS[:batch_size] - lora_paths = [adaptor_names[0]] * batch_size + self.logger.info("HFRunner context entered - model should be loaded") + self._log_system_info() # Check memory after model loading + + for config in test_configs: + batch_size = config["batch_size"] + lora_paths = config["lora_paths"] + + # Use fixed prompts for reproducibility + prompts = TEST_MOE_PROMPTS[:batch_size] + + config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" + + self.logger.info(f"=== HF Testing config: {config_key} ===") + self.logger.info(f"Batch size: {batch_size}") + self.logger.info(f"LoRA paths: {lora_paths}") + self.logger.info(f"Prompts: {prompts}") + + # Ensure reproducibility + self.logger.info("Ensuring reproducibility...") ensure_reproducibility() - outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths) - results[f"batch_{batch_size}"] = { + # Run HuggingFace + self.logger.info("Running HF forward...") + hf_outputs = hf_runner.forward( + prompts, + 
max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + self.logger.info("HF forward completed successfully") + + # Print responses + self.logger.info("=== HF Generated Responses ===") + for i, (prompt, output) in enumerate(zip(prompts, hf_outputs.output_strs)): + self.logger.info(f"Prompt {i+1}: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") + self.logger.info(f"HF Output {i+1}: {output}") + self.logger.info("") + + results[config_key] = { "prompts": prompts, "lora_paths": lora_paths, - "outputs": { - "output_strs": outputs.output_strs, - "top_input_logprobs": outputs.top_input_logprobs, - "top_output_logprobs": outputs.top_output_logprobs, + "hf_outputs": { + "output_strs": hf_outputs.output_strs, + "top_input_logprobs": hf_outputs.top_input_logprobs, + "top_output_logprobs": hf_outputs.top_output_logprobs, } } + self.logger.info(f"HF results saved for config: {config_key}") + # Save results + self.logger.info(f"Saving HF results to {result_file}") with open(result_file, 'wb') as f: pickle.dump(results, f) + self.logger.info(f"HF results saved successfully to {result_file}") + # Force GPU memory cleanup after context manager exit + self.logger.info("Forcing GPU memory cleanup...") + import torch torch.cuda.empty_cache() + torch.cuda.synchronize() + self._log_system_info() # Log memory after cleanup except Exception as e: + self.logger.error(f"HF test failed: {e}") import traceback + self.logger.error(f"Full traceback: {traceback.format_exc()}") with open(result_file, 'wb') as f: pickle.dump({"error": str(e), "traceback": traceback.format_exc()}, f) + def _run_moe_lora_comparison(self, model_case: LoRAModelCase, torch_dtype, max_new_tokens=32): + """Run LoRA comparison test by loading saved results from separate sessions.""" + self.logger.info("=== Starting MoE LoRA comparison test ===") + self._log_system_info() + + base_path = model_case.base + self.logger.info(f"Model: {base_path}, dtype: {torch_dtype}") + + # Create temp directory for results + 
with tempfile.TemporaryDirectory() as temp_dir: + srt_result_file = os.path.join(temp_dir, 'srt_results.pkl') + hf_result_file = os.path.join(temp_dir, 'hf_results.pkl') + + self.logger.info(f"Results directory: {temp_dir}") + self.logger.info(f"SRT result file: {srt_result_file}") + self.logger.info(f"HF result file: {hf_result_file}") + + # Check if results already exist + srt_done = os.path.exists(srt_result_file) + hf_done = os.path.exists(hf_result_file) + + self.logger.info(f"SRT results exist: {srt_done}") + self.logger.info(f"HF results exist: {hf_done}") + + if not srt_done: + self.logger.info("Running SGLang test...") + self._run_srt_test(srt_result_file, model_case, torch_dtype, max_new_tokens) + srt_done = True + + if not hf_done: + self.logger.info("Running HuggingFace test...") + self._run_hf_test(hf_result_file, model_case, torch_dtype, max_new_tokens) + hf_done = True + + # Load results + self.logger.info("Loading SRT results...") + if srt_done: + with open(srt_result_file, 'rb') as f: + srt_results = pickle.load(f) + if "error" in srt_results: + error_msg = f"SGLang test failed: {srt_results['error']}" + if "traceback" in srt_results: + error_msg += f"\nTraceback: {srt_results['traceback']}" + self.fail(error_msg) + self.logger.info("SRT results loaded successfully") + + self.logger.info("Loading HF results...") + if hf_done: + with open(hf_result_file, 'rb') as f: + hf_results = pickle.load(f) + if "error" in hf_results: + error_msg = f"HF test failed: {hf_results['error']}" + if "traceback" in hf_results: + error_msg += f"\nTraceback: {hf_results['traceback']}" + self.fail(error_msg) + self.logger.info("HF results loaded successfully") + + # Compare results + if srt_done and hf_done: + self.logger.info("Starting result comparison...") + for config_key in srt_results.keys(): + if config_key in hf_results: + self.logger.info(f"Comparing config: {config_key}") + srt_data = srt_results[config_key] + hf_data = hf_results[config_key] + + 
self._compare_outputs( + srt_data["srt_outputs"], hf_data["hf_outputs"], + model_case, srt_data["prompts"], srt_data["lora_paths"] + ) + else: + self.logger.warning(f"No HF results for config: {config_key}") + + self.logger.info("Comparison completed successfully") + def _compare_outputs(self, srt_outputs, hf_outputs, model_case, prompts, lora_paths): """Compare SGLang and HF outputs.""" for i, (prompt, lora_path) in enumerate(zip(prompts, lora_paths)): - srt_str = srt_outputs["output_strs"][i].strip() - hf_str = hf_outputs["output_strs"][i].strip() + print(f"\nRequest {i}: lora_path='{lora_path}'") + print(f"Prompt: {prompt[:50]}{'...' if len(prompt) > 50 else ''}") + + # Compare output strings + srt_output_str = srt_outputs.output_strs[i].strip() + hf_output_str = hf_outputs.output_strs[i].strip() + + print(f"SRT output: {srt_output_str[:100]}{'...' if len(srt_output_str) > 100 else ''}") + print(f"HF output: {hf_output_str[:100]}{'...' if len(hf_output_str) > 100 else ''}") + + # Calculate ROUGE-L similarity + rouge_l = calculate_rouge_l(srt_output_str, hf_output_str) + print(f"ROUGE-L similarity: {rouge_l:.4f}") + + # Check ROUGE-L tolerance + self.assertGreaterEqual( + rouge_l, + model_case.rouge_l_tolerance, + f"ROUGE-L similarity {rouge_l:.4f} below tolerance {model_case.rouge_l_tolerance} " + f"for request {i} with lora_path='{lora_path}'" + ) - rouge_l = calculate_rouge_l(srt_str, hf_str) - self.assertGreaterEqual(rouge_l, model_case.rouge_l_tolerance, - f"ROUGE-L {rouge_l:.4f} below tolerance for request {i}") + # Compare logprobs if available + if hasattr(srt_outputs, 'top_input_logprobs') and hasattr(hf_outputs, 'top_input_logprobs'): + if srt_outputs.top_input_logprobs[i] is not None and hf_outputs.top_input_logprobs[i] is not None: + import torch + srt_prefill = torch.tensor(srt_outputs.top_input_logprobs[i]) + hf_prefill = torch.tensor(hf_outputs.top_input_logprobs[i]) + + max_prefill_diff = torch.max(torch.abs(hf_prefill - srt_prefill)) + 
print(f"Max prefill logprob diff: {max_prefill_diff:.6f}") + + # Check prefill tolerance + prefill_tol = model_case.prefill_tolerance + self.assertLessEqual( + max_prefill_diff, + prefill_tol, + f"Prefill logprob diff {max_prefill_diff:.6f} exceeds tolerance {prefill_tol} " + f"for request {i} with lora_path='{lora_path}'" + ) + + if hasattr(srt_outputs, 'top_output_logprobs') and hasattr(hf_outputs, 'top_output_logprobs'): + if srt_outputs.top_output_logprobs[i] is not None and hf_outputs.top_output_logprobs[i] is not None: + import torch + srt_decode = torch.tensor(srt_outputs.top_output_logprobs[i]) + hf_decode = torch.tensor(hf_outputs.top_output_logprobs[i]) + + max_decode_diff = torch.max(torch.abs(hf_decode - srt_decode)) + print(f"Max decode logprob diff: {max_decode_diff:.6f}") + + # Check decode tolerance + decode_tol = model_case.decode_tolerance + self.assertLessEqual( + max_decode_diff, + decode_tol, + f"Decode logprob diff {max_decode_diff:.6f} exceeds tolerance {decode_tol} " + f"for request {i} with lora_path='{lora_path}'" + ) def test_moe_lora_qwen15(self): """Test LoRA on Qwen1.5-MoE-A2.7B.""" + self.logger.info("=== Starting test_moe_lora_qwen15 ===") + if is_in_ci(): - self.skipTest("Skipping MoE LoRA test in CI") + self.logger.info("Skipping MoE LoRA test in CI environment") + self.skipTest("Skipping MoE LoRA test in CI environment") model_case = MOE_LORA_TEST_CASES[0] - - with tempfile.TemporaryDirectory() as temp_dir: - srt_file = os.path.join(temp_dir, 'srt.pkl') - hf_file = os.path.join(temp_dir, 'hf.pkl') - - self._run_srt_test(srt_file, model_case, torch.float16) - self._run_hf_test(hf_file, model_case, torch.float16) - - with open(srt_file, 'rb') as f: - srt_results = pickle.load(f) - with open(hf_file, 'rb') as f: - hf_results = pickle.load(f) - - if "error" in srt_results: - self.fail(f"SRT failed: {srt_results['error']}") - if "error" in hf_results: - self.fail(f"HF failed: {hf_results['error']}") - - for key in srt_results: - if key 
in hf_results: - self._compare_outputs( - srt_results[key]["outputs"], - hf_results[key]["outputs"], - model_case, - srt_results[key]["prompts"], - srt_results[key]["lora_paths"] - ) + self.logger.info(f"Using model case: {model_case.base}") + + # Test with different dtypes + import torch + for torch_dtype in [torch.float16, torch.bfloat16]: + self.logger.info(f"Testing dtype: {torch_dtype}") + with self.subTest(dtype=torch_dtype): + try: + self._run_moe_lora_comparison(model_case, torch_dtype) + except Exception as e: + self.logger.error(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") + import traceback + self.logger.error(f"Traceback: {traceback.format_exc()}") + self.fail(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") def test_moe_lora_basic_functionality(self): """Basic functionality test for MoE LoRA dispatch.""" + # This test focuses on the core dispatch logic without full HF comparison + # Useful for debugging the MoE LoRA implementation + + import torch from sglang.srt.lora.moe_dispatch import moe_dispatch - topk_ids = torch.tensor([[0, 1], [2, 3], [1, 4], [5, 6]], dtype=torch.int32) + # Create test data + num_tokens = 4 + top_k = 2 + num_experts = 8 + + # Mock top-k routing results + topk_ids = torch.tensor([ + [0, 1], # token 0 routes to experts 0, 1 + [2, 3], # token 1 routes to experts 2, 3 + [1, 4], # token 2 routes to experts 1, 4 + [5, 6], # token 3 routes to experts 5, 6 + ], dtype=torch.int32) + topk_weights = torch.ones_like(topk_ids, dtype=torch.float32) - lora_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32) + # Mock LoRA indices (one per token) + lora_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32) # tokens 0,1 use lora 0; tokens 2,3 use lora 1 + + # Run dispatch token_ids, expert_ids, weights = moe_dispatch( topk_ids=topk_ids, topk_weights=topk_weights, lora_indices=lora_indices, - num_experts=8, + num_experts=num_experts, num_loras=2, ) + # Verify results + # Should have 4 tokens * 2 
experts each = 8 dispatched entries self.assertEqual(len(token_ids), 8) self.assertEqual(len(expert_ids), 8) self.assertEqual(len(weights), 8) + # Check that tokens are grouped by expert (not by LoRA) + # All tokens going to expert 0 should come first, then expert 1, etc. + unique_experts, expert_counts = torch.unique_consecutive(expert_ids, return_counts=True) + self.assertTrue(torch.all(expert_counts >= 1)) # Each expert should have at least one token + + print(f"Dispatch successful: {len(token_ids)} dispatched tokens to experts {unique_experts.tolist()}") + def debug_full_comparison(): """Debug helper to run full SRT vs HF comparison.""" + import torch + import tempfile + import os + + # Set up logging for debugging + logging.basicConfig( + level=logging.INFO, # Less verbose for debug script + format='%(asctime)s - %(levelname)s - %(message)s' + ) + test = TestMoELoRA() model_case = MOE_LORA_TEST_CASES[0] - server_args = test._get_server_args() - adaptor_names = [a.name for a in model_case.adaptors] - srt_results, hf_results = {}, {} + print("=" * 80) + print("DEBUG: Running Full SRT vs HF Comparison") + print("=" * 80) + + srt_results = {} + hf_results = {} + + # Test configurations + test_configs = [ + {"batch_size": 1, "lora_paths": [model_case.adaptors[0].name]}, # Single request, single LoRA + {"batch_size": 2, "lora_paths": [model_case.adaptors[0].name, model_case.adaptors[0].name]}, # Multiple requests, same LoRA + ] try: - # SRT tests - print("Running SGLang tests...") + # Phase 1: Run SRT tests + print("\n" + "="*50) + print("PHASE 1: Running SGLang (SRT) Tests") + print("="*50) + + server_args = test._get_srt_server_args() + print(f"Server args: quantization={server_args['quantization']}, tp_size={server_args['tp_size']}, mem_fraction_static={server_args['mem_fraction_static']}") + srt_runner = SRTRunner( model_path=model_case.base, torch_dtype=torch.float16, model_type="generation", tp_size=server_args["tp_size"], - lora_paths=adaptor_names, + 
lora_paths=[adaptor.name for adaptor in model_case.adaptors], max_loras_per_batch=server_args["max_loras_per_batch"], quantization=server_args["quantization"], max_lora_chunk_size=server_args["max_lora_chunk_size"], max_running_requests=server_args["max_running_requests"], lora_backend=server_args["lora_backend"], + disable_cuda_graph=False, disable_radix_cache=server_args["disable_radix_cache"], max_total_tokens=server_args["max_total_tokens"], page_size=server_args["page_size"], mem_fraction_static=server_args["mem_fraction_static"], + sleep_on_idle=True, ) + print("✓ SRT runner created successfully") with srt_runner: - for batch_size in [1, 2]: - prompts = TEST_PROMPTS[:batch_size] - lora_paths = [adaptor_names[0]] * batch_size + print("✓ SRT model loaded successfully") + test._log_system_info() + + for config in test_configs: + batch_size = config["batch_size"] + lora_paths = config["lora_paths"] + prompts = TEST_MOE_PROMPTS[:batch_size] + config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" + + print(f"\n--- SRT Testing {config_key} ---") + print(f"Prompts: {prompts}") + + # Ensure reproducibility ensure_reproducibility() - outputs = srt_runner.batch_forward(prompts, max_new_tokens=32, lora_paths=lora_paths) - srt_results[f"batch_{batch_size}"] = {"prompts": prompts, "lora_paths": lora_paths, "outputs": outputs} - for i, out in enumerate(outputs.output_strs): - print(f"SRT [{batch_size}] {i}: {out}") - torch.cuda.empty_cache() + srt_outputs = srt_runner.batch_forward( + prompts, + max_new_tokens=32, + lora_paths=lora_paths, + ) + print("✓ SRT batch_forward completed") + + # Print SRT responses + print("=== SRT Generated Responses ===") + for i, (prompt, output) in enumerate(zip(prompts, srt_outputs.output_strs)): + print(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") + print(f"SRT Output {i+1}: {output}") + print("") + + srt_results[config_key] = { + "prompts": prompts, + "lora_paths": lora_paths, + "srt_outputs": srt_outputs + } + + print("✓ All SRT tests completed successfully") + + # Clear GPU memory before HF test + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("✓ GPU cache cleared before HF test") + + # Small delay to ensure memory is fully released + import time + time.sleep(2) + test._log_system_info() + + # Phase 2: Run HF tests + print("\n" + "="*50) + print("PHASE 2: Running HuggingFace (HF) Tests") + print("="*50) - # HF tests - print("\nRunning HuggingFace tests...") + server_args = test._get_srt_server_args() hf_runner = HFRunner( model_case.base, torch_dtype=torch.float16, model_type="generation", - trust_remote_code=True, - quantization=server_args["quantization"], - device_map="auto", + trust_remote_code=True, # Match SRTRunner behavior + quantization=server_args["quantization"], # Enable quantization if specified + device_map="auto", # Distribute across available GPUs ) + print("✓ HF runner created successfully") with hf_runner: - for batch_size in [1, 2]: - prompts = TEST_PROMPTS[:batch_size] - lora_paths = [adaptor_names[0]] * batch_size + print("✓ HF model loaded successfully") + test._log_system_info() + + for config in test_configs: + batch_size = config["batch_size"] + lora_paths = config["lora_paths"] + prompts = TEST_MOE_PROMPTS[:batch_size] + config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" + + print(f"\n--- HF Testing {config_key} ---") + print(f"Prompts: {prompts}") + + # Ensure reproducibility ensure_reproducibility() - outputs = hf_runner.forward(prompts, max_new_tokens=32, lora_paths=lora_paths) - hf_results[f"batch_{batch_size}"] = {"prompts": prompts, "lora_paths": lora_paths, "outputs": outputs} - for i, out in enumerate(outputs.output_strs): - print(f"HF [{batch_size}] {i}: {out}") - - 
torch.cuda.empty_cache() - - # Compare - print("\nComparing outputs...") - for key in srt_results: - srt_data, hf_data = srt_results[key], hf_results[key] - for i in range(len(srt_data["prompts"])): - rouge_l = calculate_rouge_l( - srt_data["outputs"].output_strs[i], - hf_data["outputs"].output_strs[i] + + hf_outputs = hf_runner.forward( + prompts, + max_new_tokens=32, + lora_paths=lora_paths, ) - print(f"{key} request {i}: ROUGE-L = {rouge_l:.4f}") + print("✓ HF forward completed") + + # Print HF responses + print("=== HF Generated Responses ===") + for i, (prompt, output) in enumerate(zip(prompts, hf_outputs.output_strs)): + print(f"Prompt {i+1}: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") + print(f"HF Output {i+1}: {output}") + print("") - print("\nAll comparisons completed!") - return True + hf_results[config_key] = { + "prompts": prompts, + "lora_paths": lora_paths, + "hf_outputs": hf_outputs + } + + print("✓ All HF tests completed successfully") + + # Force final GPU memory cleanup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("✓ Final GPU memory cleanup completed") + test._log_system_info() + + # Phase 3: Compare results + print("\n" + "="*50) + print("PHASE 3: Comparing SRT vs HF Outputs") + print("="*50) + + for config_key in srt_results.keys(): + if config_key in hf_results: + srt_data = srt_results[config_key] + hf_data = hf_results[config_key] + + print(f"\n{'='*30} Comparing {config_key} {'='*30}") + try: + test._compare_outputs( + srt_data["srt_outputs"], hf_data["hf_outputs"], + model_case, srt_data["prompts"], srt_data["lora_paths"] + ) + print(f"✓ Comparison passed for {config_key}") + except AssertionError as e: + print(f"✗ Comparison failed for {config_key}: {e}") + raise + else: + print(f"✗ No HF results for config: {config_key}") + + print("\n" + "="*80) + print("🎉 ALL COMPARISONS PASSED! 
SRT and HF outputs match!") + print("="*80) except Exception as e: + print(f"\n💥 DEBUG SCRIPT FAILED: {e}") import traceback - print(f"Failed: {e}") traceback.print_exc() return False + return True + if __name__ == "__main__": import sys + if len(sys.argv) > 1 and sys.argv[1] == "--debug": + # Run debug tests instead of unittest success = debug_full_comparison() sys.exit(0 if success else 1) else: @@ -334,4 +829,5 @@ def debug_full_comparison(): mp.set_start_method("spawn") except RuntimeError: pass + unittest.main() From af3c75832efd486ead3c7c4f72bef1a3b9dcf08c Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 29 Nov 2025 01:27:34 +0000 Subject: [PATCH 030/150] remove unnecessary code --- python/sglang/srt/lora/lora_manager.py | 13 ++----------- python/sglang/srt/lora/mem_pool.py | 14 +++----------- 2 files changed, 5 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 2a8384459616..a1eccfa9c9a2 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -325,16 +325,14 @@ def update_lora_info(self): if isinstance(module, FusedMoEWithLoRA) and all(x in self.target_modules for x in ['gate_up_proj', 'down_proj']): module.set_lora_info( self.memory_pool.get_tensor( - target_module='gate_up_proj', + target_module='gate_up_proj_moe', layer_id=layer_id, lora_type=LoRAType.LORA_A, - context='moe', ), self.memory_pool.get_tensor( - target_module='down_proj', + target_module='down_proj_moe', layer_id=layer_id, lora_type=LoRAType.LORA_B, - context='moe', ), ) continue @@ -343,23 +341,16 @@ def update_lora_info(self): module_name, self.memory_pool.target_modules ) - # Determine context based on module name - context = None - if isinstance(module, FusedMoEWithLoRA): - context = "moe" - module.set_lora_info( self.memory_pool.get_tensor( target_module=target_module, layer_id=layer_id, lora_type=LoRAType.LORA_A, - context=context, ), 
self.memory_pool.get_tensor( target_module=target_module, layer_id=layer_id, lora_type=LoRAType.LORA_B, - context=context, ), ) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 76985439e24e..1905e6567ad0 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -480,30 +480,22 @@ def load_lora_weight_tensor( load_lora_weight_tensor(buffer_view, weights) def get_tensor( - self, target_module: str, layer_id: int, lora_type: LoRAType, context: str = None + self, target_module: str, layer_id: int, lora_type: LoRAType ) -> torch.Tensor: """ Get LoRA tensor buffer (automatically handles both 3D and 4D tensors). Args: - target_module: Target module name (e.g., 'gate_up_proj') + target_module: Target module name (e.g., 'gate_up_proj' or 'gate_up_proj_moe' for MoE) layer_id: Layer index lora_type: LoRAType.LORA_A or LoRAType.LORA_B - context: Optional context hint ('moe' or None for auto-detect) Returns: - 3D tensor [num_loras, rank, hidden] for standard modules - 4D tensor [num_loras, num_experts, rank, hidden] for MoE modules """ buffer_dict = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer - - # Handle context-specific buffer selection for ambiguous modules - ambiguous_modules = {"gate_up_proj", "down_proj"} - if target_module in ambiguous_modules: - if context == "moe" and f"{target_module}_moe" in buffer_dict: - return buffer_dict[f"{target_module}_moe"][layer_id] - - # Fall back to original key for non-ambiguous modules + return buffer_dict[target_module][layer_id] def get_buffer_id(self, lora_uid: str): From d3b27ee61415e6b4fc2e0c351ecd670a1dac242b Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 29 Nov 2025 01:31:40 +0000 Subject: [PATCH 031/150] rename vars for clarity --- python/sglang/srt/layers/moe/lora_moe.py | 4 +--- python/sglang/srt/lora/moe_dispatch.py | 13 +++---------- test/srt/lora/test_lora_moe.py | 7 +++---- 3 files changed, 7 insertions(+), 17 
deletions(-) diff --git a/python/sglang/srt/layers/moe/lora_moe.py b/python/sglang/srt/layers/moe/lora_moe.py index 2ea43c6380c6..07b113013d65 100644 --- a/python/sglang/srt/layers/moe/lora_moe.py +++ b/python/sglang/srt/layers/moe/lora_moe.py @@ -104,12 +104,10 @@ def _compute_lora_delta( num_loras = self.lora_a_weights.shape[0] # Dispatch tokens to experts - token_ids, expert_ids, _, lora_ids = moe_dispatch( + token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch( topk_ids=topk_ids, topk_weights=topk_weights, lora_indices=lora_indices, - num_experts=num_experts, - num_loras=num_loras, ) diff --git a/python/sglang/srt/lora/moe_dispatch.py b/python/sglang/srt/lora/moe_dispatch.py index a34bf181af65..de7106a489b9 100644 --- a/python/sglang/srt/lora/moe_dispatch.py +++ b/python/sglang/srt/lora/moe_dispatch.py @@ -21,8 +21,6 @@ def moe_dispatch( topk_ids: torch.Tensor, topk_weights: torch.Tensor, lora_indices: torch.Tensor, - num_experts: int, - num_loras: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Dispatch tokens to experts for MoE computation. 
@@ -31,13 +29,11 @@ def moe_dispatch( topk_ids: [num_tokens, top_k] - Expert IDs selected by router topk_weights: [num_tokens, top_k] - Router weights lora_indices: [num_tokens] - LoRA adapter ID for each token - num_experts: Total number of experts - num_loras: Total number of LoRA adapters Returns: sorted_token_ids: Token indices sorted by expert_id sorted_expert_ids: Corresponding expert IDs - sorted_weights: Corresponding router weights + sorted_topk_weights: Corresponding router weights sorted_lora_ids: LoRA adapter IDs for each dispatched token """ num_tokens, top_k = topk_ids.shape @@ -54,10 +50,7 @@ def moe_dispatch( sorted_token_ids = flat_token_ids[sorted_indices] sorted_expert_ids = flat_topk_ids[sorted_indices] - sorted_weights = flat_topk_weights[sorted_indices] - - if flat_lora_ids.shape != sorted_indices.shape: - y = 1 # need to pause + sorted_topk_weights = flat_topk_weights[sorted_indices] sorted_lora_ids = flat_lora_ids[sorted_indices] - return sorted_token_ids, sorted_expert_ids, sorted_weights, sorted_lora_ids + return sorted_token_ids, sorted_expert_ids, sorted_topk_weights, sorted_lora_ids diff --git a/test/srt/lora/test_lora_moe.py b/test/srt/lora/test_lora_moe.py index 08bc1ec71637..5cc366968e1a 100644 --- a/test/srt/lora/test_lora_moe.py +++ b/test/srt/lora/test_lora_moe.py @@ -592,19 +592,18 @@ def test_moe_lora_basic_functionality(self): lora_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32) # tokens 0,1 use lora 0; tokens 2,3 use lora 1 # Run dispatch - token_ids, expert_ids, weights = moe_dispatch( + token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch( topk_ids=topk_ids, topk_weights=topk_weights, lora_indices=lora_indices, - num_experts=num_experts, - num_loras=2, ) # Verify results # Should have 4 tokens * 2 experts each = 8 dispatched entries self.assertEqual(len(token_ids), 8) self.assertEqual(len(expert_ids), 8) - self.assertEqual(len(weights), 8) + self.assertEqual(len(sorted_topk_weights), 8) + 
self.assertEqual(len(lora_ids), 8) # Check that tokens are grouped by expert (not by LoRA) # All tokens going to expert 0 should come first, then expert 1, etc. From 53dc64dcd979714d3b67ba86c41b7594c03cffd0 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 13 Dec 2025 00:33:23 +0000 Subject: [PATCH 032/150] move from lora_moe.py to layers.py --- python/sglang/srt/layers/moe/lora_moe.py | 127 ----------------------- python/sglang/srt/lora/layers.py | 114 +++++++++++++++++++- python/sglang/srt/lora/lora_manager.py | 2 +- python/sglang/srt/lora/mem_pool.py | 2 +- 4 files changed, 115 insertions(+), 130 deletions(-) delete mode 100644 python/sglang/srt/layers/moe/lora_moe.py diff --git a/python/sglang/srt/layers/moe/lora_moe.py b/python/sglang/srt/layers/moe/lora_moe.py deleted file mode 100644 index 07b113013d65..000000000000 --- a/python/sglang/srt/layers/moe/lora_moe.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""FusedMoE layer with LoRA support.""" - -import torch -from torch import nn - -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.moe.topk import TopKOutput -from sglang.srt.lora.backend.base_backend import BaseLoRABackend - - -class FusedMoEWithLoRA(nn.Module): - """ - Wrapper around FusedMoE that adds parallel LoRA computation. 
- - Design: Base MoE and LoRA Delta run independently and merge at the end. - This preserves SGLang's existing 3-stage MoE architecture unchanged. - """ - - def __init__( - self, - base_moe: FusedMoE, - lora_backend: BaseLoRABackend, - ): - super().__init__() - self.base_moe = base_moe - self.lora_backend = lora_backend - self.lora_enabled = False - - # LoRA tensors will be set by LoRAManager - self.lora_a_weights = None - self.lora_b_weights = None - - def set_lora_info( - self, - lora_a_weights: torch.Tensor, - lora_b_weights: torch.Tensor, - ): - """Set LoRA weight tensors from memory pool.""" - self.lora_enabled = True - self.lora_a_weights = lora_a_weights - self.lora_b_weights = lora_b_weights - - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): - """ - Forward pass with parallel LoRA computation. - - Flow: - 1. Base MoE forward - 2. Parallel LoRA delta computation (if enabled, added in-place) - 3. Return modified base_output - """ - # Run base MoE - base_output = self.base_moe.forward(hidden_states, topk_output, **kwargs) - - # If LoRA is enabled, compute delta and add in-place for memory efficiency - if self.lora_enabled and self.lora_a_weights is not None: - self._compute_lora_delta(hidden_states, topk_output, base_output) - - return base_output - - def _compute_lora_delta( - self, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - base_output: torch.Tensor, - ) -> None: - """ - Compute LoRA delta using per-expert LoRA weights and add to base_output in-place. - - Dispatch tokens to experts and compute per-expert deltas. 
- """ - from sglang.srt.lora.moe_dispatch import moe_dispatch - from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( - per_expert_lora_forward, - ) - - # Get dispatch info from TopKOutput - topk_ids = topk_output.topk_ids # [num_tokens, top_k] - topk_weights = topk_output.topk_weights # [num_tokens, top_k] - - # Get LoRA batch info from backend - batch_info = self.lora_backend.batch_info - lora_ranks = batch_info.lora_ranks # [num_loras] - scalings = batch_info.scalings # [num_loras] - - # Use precomputed per-token LoRA indices from forward batch - lora_indices = self.lora_backend.forward_batch.token_lora_indices - - num_experts = self.base_moe.num_experts - num_loras = self.lora_a_weights.shape[0] - - # Dispatch tokens to experts - token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch( - topk_ids=topk_ids, - topk_weights=topk_weights, - lora_indices=lora_indices, - ) - - - - # Compute per-expert LoRA forward (adds to base_output in-place) - per_expert_lora_forward( - hidden_states=hidden_states, - lora_a_weights=self.lora_a_weights, - lora_b_weights=self.lora_b_weights, - token_ids=token_ids, - expert_ids=expert_ids, - lora_ids=lora_ids, - lora_ranks=lora_ranks, - lora_scalings=scalings, - num_experts=num_experts, - base_output=base_output, - ) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 88b0a5a85671..98c89178637b 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -20,6 +20,7 @@ ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.utils import LoRABatchInfo @@ -576,11 +577,122 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): return B +class FusedMoEWithLoRA(BaseLayerWithLoRA): + """ + Wrapper around FusedMoE that adds parallel LoRA computation. + + Design: Base MoE and LoRA Delta run independently and merge at the end. 
+ This preserves SGLang's existing 3-stage MoE architecture unchanged. + """ + + def __init__( + self, + base_moe: FusedMoE, + lora_backend: BaseLoRABackend, + ): + super().__init__(base_moe, lora_backend) + # LoRA tensors will be set by LoRAManager + self.lora_a_weights = None + self.lora_b_weights = None + + def set_lora_info( + self, + lora_a_weights: torch.Tensor, + lora_b_weights: torch.Tensor, + ): + """Set LoRA weight tensors from memory pool.""" + self.set_lora = True + self.lora_a_weights = lora_a_weights + self.lora_b_weights = lora_b_weights + + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): + """ + Forward pass with parallel LoRA computation. + + Flow: + 1. Base MoE forward + 2. Parallel LoRA delta computation (if enabled, added in-place) + 3. Return modified base_output + """ + # Run base MoE + base_output = self.base_moe.forward(hidden_states, topk_output, **kwargs) + + # If LoRA is enabled, compute delta and add in-place for memory efficiency + if self.set_lora and self.lora_a_weights is not None: + self._compute_lora_delta(hidden_states, topk_output, base_output) + + return base_output + + def _compute_lora_delta( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + base_output: torch.Tensor, + ) -> None: + """ + Compute LoRA delta using per-expert LoRA weights and add to base_output in-place. + + Dispatch tokens to experts and compute per-expert deltas. 
+ """ + from sglang.srt.lora.moe_dispatch import moe_dispatch + from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( + per_expert_lora_forward, + ) + + # Get dispatch info from TopKOutput + topk_ids = topk_output.topk_ids # [num_tokens, top_k] + topk_weights = topk_output.topk_weights # [num_tokens, top_k] + + # Get LoRA batch info from backend + batch_info = self.lora_backend.batch_info + lora_ranks = batch_info.lora_ranks # [num_loras] + scalings = batch_info.scalings # [num_loras] + + # Use precomputed per-token LoRA indices from forward batch + lora_indices = self.lora_backend.forward_batch.token_lora_indices + + num_experts = self.base_moe.num_experts + num_loras = self.lora_a_weights.shape[0] + + # Dispatch tokens to experts + token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch( + topk_ids=topk_ids, + topk_weights=topk_weights, + lora_indices=lora_indices, + ) + + + + # Compute per-expert LoRA forward (adds to base_output in-place) + per_expert_lora_forward( + hidden_states=hidden_states, + lora_a_weights=self.lora_a_weights, + lora_b_weights=self.lora_b_weights, + token_ids=token_ids, + expert_ids=expert_ids, + lora_ids=lora_ids, + lora_ranks=lora_ranks, + lora_scalings=scalings, + num_experts=num_experts, + base_output=base_output, + ) + + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): + # For MoE layers, tensor parallelism is typically not used + # Return weights unchanged + return A + + def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): + # For MoE layers, tensor parallelism is typically not used + # Return weights unchanged + return B + + def get_lora_layer( layer: nn.Module, lora_backend: BaseLoRABackend ) -> BaseLayerWithLoRA: from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE - from sglang.srt.layers.moe.lora_moe import FusedMoEWithLoRA + # FusedMoEWithLoRA is now defined in this file supported_layer_types = { # the order matters diff --git a/python/sglang/srt/lora/lora_manager.py 
b/python/sglang/srt/lora/lora_manager.py index 1eb77cf637c4..29e125314bac 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -44,7 +44,7 @@ from sglang.srt.utils import replace_submodule from sglang.srt.utils.hf_transformers_utils import AutoConfig from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.moe.lora_moe import FusedMoEWithLoRA +from sglang.srt.lora.layers import FusedMoEWithLoRA logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 5874ac2d4ee7..f1305e08933b 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -504,7 +504,7 @@ def load_lora_weight_tensor( for module_name, module in cur_layer_modules.items(): # TODO (Jonahcb): check if the code can be refactored to avoid the special handling for FusedMoEWithLoRA # Handle FusedMoEWithLoRA specially - it contains multiple target modules - from sglang.srt.layers.moe.lora_moe import FusedMoEWithLoRA + from sglang.srt.lora.layers import FusedMoEWithLoRA if isinstance(module, FusedMoEWithLoRA): # FusedMoEWithLoRA contains both gate_up_proj and down_proj moe_target_modules = ['gate_up_proj_moe', 'down_proj_moe'] From ad9c32e452c51a824fa21cbcd46275bc6f75363b Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 13 Dec 2025 03:02:55 +0000 Subject: [PATCH 033/150] clean up test files --- python/sglang/srt/lora/layers.py | 11 +- python/sglang/srt/lora/lora_manager.py | 29 +- .../srt/lora/test_lora_hf_sgl_logprob_diff.py | 36 + test/srt/lora/test_lora_moe.py | 832 ------------------ 4 files changed, 54 insertions(+), 854 deletions(-) delete mode 100644 test/srt/lora/test_lora_moe.py diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 98c89178637b..d146233cf69d 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -20,6 +20,7 @@ ParallelLMHead, 
VocabParallelEmbedding, ) +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.utils import LoRABatchInfo @@ -587,10 +588,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): def __init__( self, - base_moe: FusedMoE, + base_layer: nn.Module, lora_backend: BaseLoRABackend, ): - super().__init__(base_moe, lora_backend) + super().__init__(base_layer, lora_backend) # LoRA tensors will be set by LoRAManager self.lora_a_weights = None self.lora_b_weights = None @@ -615,7 +616,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs 3. Return modified base_output """ # Run base MoE - base_output = self.base_moe.forward(hidden_states, topk_output, **kwargs) + base_output = self.base_layer.forward(hidden_states, topk_output, **kwargs) # If LoRA is enabled, compute delta and add in-place for memory efficiency if self.set_lora and self.lora_a_weights is not None: @@ -651,7 +652,7 @@ def _compute_lora_delta( # Use precomputed per-token LoRA indices from forward batch lora_indices = self.lora_backend.forward_batch.token_lora_indices - num_experts = self.base_moe.num_experts + num_experts = self.base_layer.num_experts num_loras = self.lora_a_weights.shape[0] # Dispatch tokens to experts @@ -691,7 +692,7 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): def get_lora_layer( layer: nn.Module, lora_backend: BaseLoRABackend ) -> BaseLayerWithLoRA: - from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + # FusedMoE is now imported at the top of the file # FusedMoEWithLoRA is now defined in this file supported_layer_types = { diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 29e125314bac..bc3697752723 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -285,27 +285,22 @@ def 
prepare_lora_batch(self, forward_batch: ForwardBatch): # Populate per-token LoRA indices from segment information batch_info = self.lora_backend.batch_info - num_tokens = forward_batch.seq_lens_sum # Total tokens across all sequences + num_tokens = forward_batch.input_ids.shape[0] # Tokens in current forward pass + + # Create tensor and fill with adapter indices from segments + token_lora_indices_reordered = torch.empty(num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device) + seg_indptr = batch_info.seg_indptr # [num_segments + 1] + for seg_idx in range(batch_info.num_segments): + start_token = seg_indptr[seg_idx] + end_token = seg_indptr[seg_idx + 1] + lora_adapter = batch_info.weight_indices[seg_idx] + token_lora_indices_reordered[start_token:end_token] = lora_adapter + if batch_info.permutation is None: # No reordering (e.g., triton backend): segments are in original order - token_lora_indices = torch.empty(num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device) - seg_indptr = batch_info.seg_indptr # [num_segments + 1] - for seg_idx in range(batch_info.num_segments): - start_token = seg_indptr[seg_idx] - end_token = seg_indptr[seg_idx + 1] - lora_adapter = batch_info.weight_indices[seg_idx] - token_lora_indices[start_token:end_token] = lora_adapter + token_lora_indices = token_lora_indices_reordered else: # Tokens are reordered (chunked backend): need to convert back to original order - token_lora_indices_reordered = torch.empty(num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device) - seg_indptr = batch_info.seg_indptr # [num_segments + 1] - for seg_idx in range(batch_info.num_segments): - start_token = seg_indptr[seg_idx] - end_token = seg_indptr[seg_idx + 1] - lora_adapter = batch_info.weight_indices[seg_idx] - token_lora_indices_reordered[start_token:end_token] = lora_adapter - - # Convert back to original token order using inverse permutation inverse_permutation = torch.empty_like(batch_info.permutation) 
inverse_permutation[batch_info.permutation] = torch.arange(num_tokens, dtype=batch_info.permutation.dtype, device=batch_info.permutation.device) token_lora_indices = token_lora_indices_reordered[inverse_permutation] diff --git a/test/srt/lora/test_lora_hf_sgl_logprob_diff.py b/test/srt/lora/test_lora_hf_sgl_logprob_diff.py index b0975fa5d666..0b20054e5f3b 100644 --- a/test/srt/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/srt/lora/test_lora_hf_sgl_logprob_diff.py @@ -543,6 +543,42 @@ def test_lora_logprob_comparison_full(self): max_new_tokens=32, ) + def test_moe_lora_logprob_comparison_basic(self): + """ + Test comparing HF and SGLang MoE LoRA logprobs with basic prompts. + """ + if is_in_ci(): + self.skipTest("Skipping in CI environment - requires large MoE models") + + model_path = "Qwen/Qwen1.5-MoE-A2.7B" + lora_paths = ["sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest"] + prompts = DEFAULT_TEST_PROMPTS[:2] # Use first 2 default prompts for basic test + + self._run_comparison_test( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=32, + ) + + def test_moe_lora_logprob_comparison_full(self): + """ + Full test comparing HF and SGLang MoE LoRA logprobs with all default prompts. 
+ """ + if is_in_ci(): + self.skipTest("Skipping in CI environment - requires large MoE models") + + model_path = "Qwen/Qwen1.5-MoE-A2.7B" + lora_paths = ["sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest"] + prompts = DEFAULT_TEST_PROMPTS + + self._run_comparison_test( + model_path=model_path, + lora_paths=lora_paths, + prompts=prompts, + max_new_tokens=32, + ) + if __name__ == "__main__": try: diff --git a/test/srt/lora/test_lora_moe.py b/test/srt/lora/test_lora_moe.py deleted file mode 100644 index 5cc366968e1a..000000000000 --- a/test/srt/lora/test_lora_moe.py +++ /dev/null @@ -1,832 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -""" -Test MoE LoRA implementation by comparing against HuggingFace. - -This test file verifies that SGLang's MoE LoRA implementation produces the same -outputs as HuggingFace's PEFT library for MoE models. - -WORKFLOW: -1. Run SGLang test in first session (loads SGLang model, saves results to temp file) -2. Run HF test in second session (loads HF model, saves results to temp file) -3. Compare saved results from both sessions - -This avoids loading both models simultaneously, saving GPU memory. 
- -Usage: - # Run basic functionality test (no HF comparison) - python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_basic_functionality -v - - # Run full HF comparison test (requires model and LoRA adapter) - # This will run SGLang and HF tests separately, then compare results - python -m pytest test/srt/lora/test_lora_moe.py::TestMoELoRA::test_moe_lora_qwen15 -v - -Manual Testing: - # Run full SRT vs HF comparison (sequential, saves memory) - python test_lora_moe.py --debug - - # Or run individual components separately: - # Terminal 1: Run SGLang test and save results - python -c " - from test.srt.lora.test_lora_moe import TestMoELoRA - import torch - import tempfile - import os - - test = TestMoELoRA() - model_case = test.MOE_LORA_TEST_CASES[0] - - with tempfile.TemporaryDirectory() as temp_dir: - result_file = os.path.join(temp_dir, 'srt_results.pkl') - test._run_srt_test(result_file, model_case, torch.float16) - print(f'SGLang results saved to {result_file}') - input('Press Enter after copying result file...') - " - - # Terminal 2: Run HF test and save results - python -c " - from test.srt.lora.test_lora_moe import TestMoELoRA - import torch - import tempfile - import os - - test = TestMoELoRA() - model_case = test.MOE_LORA_TEST_CASES[0] - - with tempfile.TemporaryDirectory() as temp_dir: - result_file = os.path.join(temp_dir, 'hf_results.pkl') - test._run_hf_test(result_file, model_case, torch.float16) - print(f'HF results saved to {result_file}') - input('Press Enter after copying result file...') - " - -Requirements: - - Qwen/Qwen1.5-MoE-A2.7B model - - sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest LoRA adapter - - Sufficient GPU memory (tests run in separate sessions) - - Uses same memory configuration as launch.json for compatibility -""" - -import json -import logging -import multiprocessing as mp -import os -import pickle -import psutil -import random -import tempfile -import time -import torch -import unittest -from pathlib 
import Path - -from utils import LoRAModelCase, LoRAAdaptor, ensure_reproducibility - -from sglang.test.runners import HFRunner, SRTRunner -from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci - -# Test prompts for MoE LoRA comparison -TEST_MOE_PROMPTS = [ - "The capital of France is Paris. The capital of", - "Explain what mixture of experts means in machine learning.", - "Write a short poem about artificial intelligence and large language models.", - "What are the benefits of using MoE architectures in transformers?", -] - -# MoE model test cases with LoRA adapters -MOE_LORA_TEST_CASES = [ - LoRAModelCase( - base="Qwen/Qwen1.5-MoE-A2.7B", - adaptors=[ - # Use a real LoRA adapter path - replace with actual path when testing - LoRAAdaptor( - name="sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest", - prefill_tolerance=1e-1, - decode_tolerance=1e-1, - rouge_l_tolerance=1.0 - ), - ], - tp_size=1, - prefill_tolerance=1e-1, - decode_tolerance=1e-1, - rouge_l_tolerance=1.0, - max_loras_per_batch=1, - ), - # Add more MoE models here when available -] - - -class TestMoELoRA(CustomTestCase): - """Test MoE LoRA implementation by comparing against HuggingFace.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Set up detailed logging - logging.basicConfig( - level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) - self.logger = logging.getLogger(__name__) - - def _log_system_info(self): - """Log system and GPU memory information.""" - try: - # CPU memory - memory = psutil.virtual_memory() - self.logger.info(f"System Memory: {memory.available / (1024**3):.2f}GB available / {memory.total / (1024**3):.2f}GB total") - - # GPU memory if available - if torch.cuda.is_available(): - gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) - allocated = torch.cuda.memory_allocated(0) / (1024**3) - reserved = torch.cuda.memory_reserved(0) / (1024**3) - self.logger.info(f"GPU 
Memory: {gpu_memory:.2f}GB total, {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") - else: - self.logger.warning("CUDA not available") - except Exception as e: - self.logger.warning(f"Could not get system info: {e}") - - def _get_srt_server_args(self): - """Get SGLang server arguments from launch.json config.""" - return { - "quantization": "fp8", - "disable_radix_cache": True, - "lora_backend": "csgmv", - "max_lora_chunk_size": 16, - "port": 30000, - "host": "127.0.0.1", - "max_loras_per_batch": 1, - "tp_size": 2, - "max_total_tokens": 128, - "page_size": 64, - "max_running_requests": 1, - "mem_fraction_static": 0.85, - } - - def _run_srt_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): - """Run SGLang test and save results to file.""" - try: - self.logger.info("=== Starting SGLang MoE LoRA test ===") - self._log_system_info() - - base_path = model_case.base - adaptor_names = [adaptor.name for adaptor in model_case.adaptors] - - self.logger.info(f"Model: {base_path}") - self.logger.info(f"LoRA adapters: {adaptor_names}") - self.logger.info(f"dtype: {torch_dtype}") - - server_args = self._get_srt_server_args() - server_args["model_path"] = base_path - server_args["lora_paths"] = adaptor_names - server_args["enable_lora"] = True - - self.logger.info(f"Server args: {server_args}") - self.logger.info("Creating SRTRunner...") - - # Initialize SGLang runner with launch.json args - self.logger.info("Initializing SRTRunner...") - srt_runner = SRTRunner( - model_path=base_path, - torch_dtype=torch_dtype, - model_type="generation", - tp_size=server_args["tp_size"], - lora_paths=adaptor_names, - max_loras_per_batch=server_args["max_loras_per_batch"], - quantization=server_args["quantization"], - max_lora_chunk_size=server_args["max_lora_chunk_size"], - max_running_requests=server_args["max_running_requests"], - lora_backend=server_args["lora_backend"], - disable_cuda_graph=False, - disable_radix_cache=server_args["disable_radix_cache"], - 
max_total_tokens=server_args["max_total_tokens"], - page_size=server_args["page_size"], - mem_fraction_static=server_args["mem_fraction_static"], - sleep_on_idle=True, - ) - self.logger.info("SRTRunner created successfully") - - results = {} - self._log_system_info() # Check memory after runner creation - - # Test with different batch configurations - test_configs = [ - {"batch_size": 1, "lora_paths": [adaptor_names[0]]}, # Single request, single LoRA - {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA - ] - - self.logger.info("Entering SRTRunner context manager (this loads the model)...") - with srt_runner: - self.logger.info("SRTRunner context entered - model should be loaded") - self._log_system_info() # Check memory after model loading - - for config in test_configs: - batch_size = config["batch_size"] - lora_paths = config["lora_paths"] - - # Use fixed prompts for reproducibility - prompts = TEST_MOE_PROMPTS[:batch_size] - - config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - - self.logger.info(f"=== Testing config: {config_key} ===") - self.logger.info(f"Batch size: {batch_size}") - self.logger.info(f"LoRA paths: {lora_paths}") - self.logger.info(f"Prompts: {prompts}") - - # Ensure reproducibility - self.logger.info("Ensuring reproducibility...") - ensure_reproducibility() - - # Run SGLang - self.logger.info("Running batch_forward...") - srt_outputs = srt_runner.batch_forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - self.logger.info("batch_forward completed successfully") - - # Print responses - self.logger.info("=== SRT Generated Responses ===") - for i, (prompt, output) in enumerate(zip(prompts, srt_outputs.output_strs)): - self.logger.info(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") - self.logger.info(f"SRT Output {i+1}: {output}") - self.logger.info("") - - results[config_key] = { - "prompts": prompts, - "lora_paths": lora_paths, - "srt_outputs": { - "output_strs": srt_outputs.output_strs, - "top_input_logprobs": srt_outputs.top_input_logprobs, - "top_output_logprobs": srt_outputs.top_output_logprobs, - } - } - self.logger.info(f"Results saved for config: {config_key}") - - # Save results - self.logger.info(f"Saving results to {result_file}") - with open(result_file, 'wb') as f: - pickle.dump(results, f) - self.logger.info(f"SGLang results saved successfully to {result_file}") - - # Force GPU memory cleanup after context manager exit - self.logger.info("Forcing GPU memory cleanup...") - import torch - torch.cuda.empty_cache() - torch.cuda.synchronize() - self._log_system_info() # Log memory after cleanup - - except Exception as e: - self.logger.error(f"SGLang test failed: {e}") - import traceback - self.logger.error(f"Full traceback: {traceback.format_exc()}") - with open(result_file, 'wb') as f: - pickle.dump({"error": str(e), "traceback": traceback.format_exc()}, f) - - def _run_hf_test(self, result_file, model_case, torch_dtype, max_new_tokens=32): - """Run HuggingFace test and save results to file.""" - try: - self.logger.info("=== Starting HF MoE LoRA test ===") - self._log_system_info() - - base_path = model_case.base - adaptor_names = [adaptor.name for adaptor in model_case.adaptors] - - self.logger.info(f"Model: {base_path}") - self.logger.info(f"LoRA adapters: {adaptor_names}") - self.logger.info(f"dtype: {torch_dtype}") - - # Get server args for consistency - server_args = self._get_srt_server_args() - - # Initialize HF runner - self.logger.info("Creating HFRunner...") - hf_runner = HFRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - trust_remote_code=True, # Match SRTRunner behavior - quantization=server_args["quantization"], # Enable quantization if specified - 
device_map="auto", # Distribute across available GPUs (like tp_size for SRT) - ) - self.logger.info("HFRunner created successfully") - self._log_system_info() # Check memory after runner creation - - results = {} - - # Test with different batch configurations - test_configs = [ - {"batch_size": 1, "lora_paths": [adaptor_names[0]]}, # Single request, single LoRA - {"batch_size": 2, "lora_paths": [adaptor_names[0], adaptor_names[0]]}, # Multiple requests, same LoRA - ] - - self.logger.info("Entering HFRunner context manager (this loads the model)...") - with hf_runner: - self.logger.info("HFRunner context entered - model should be loaded") - self._log_system_info() # Check memory after model loading - - for config in test_configs: - batch_size = config["batch_size"] - lora_paths = config["lora_paths"] - - # Use fixed prompts for reproducibility - prompts = TEST_MOE_PROMPTS[:batch_size] - - config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - - self.logger.info(f"=== HF Testing config: {config_key} ===") - self.logger.info(f"Batch size: {batch_size}") - self.logger.info(f"LoRA paths: {lora_paths}") - self.logger.info(f"Prompts: {prompts}") - - # Ensure reproducibility - self.logger.info("Ensuring reproducibility...") - ensure_reproducibility() - - # Run HuggingFace - self.logger.info("Running HF forward...") - hf_outputs = hf_runner.forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=lora_paths, - ) - self.logger.info("HF forward completed successfully") - - # Print responses - self.logger.info("=== HF Generated Responses ===") - for i, (prompt, output) in enumerate(zip(prompts, hf_outputs.output_strs)): - self.logger.info(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") - self.logger.info(f"HF Output {i+1}: {output}") - self.logger.info("") - - results[config_key] = { - "prompts": prompts, - "lora_paths": lora_paths, - "hf_outputs": { - "output_strs": hf_outputs.output_strs, - "top_input_logprobs": hf_outputs.top_input_logprobs, - "top_output_logprobs": hf_outputs.top_output_logprobs, - } - } - self.logger.info(f"HF results saved for config: {config_key}") - - # Save results - self.logger.info(f"Saving HF results to {result_file}") - with open(result_file, 'wb') as f: - pickle.dump(results, f) - self.logger.info(f"HF results saved successfully to {result_file}") - - # Force GPU memory cleanup after context manager exit - self.logger.info("Forcing GPU memory cleanup...") - import torch - torch.cuda.empty_cache() - torch.cuda.synchronize() - self._log_system_info() # Log memory after cleanup - - except Exception as e: - self.logger.error(f"HF test failed: {e}") - import traceback - self.logger.error(f"Full traceback: {traceback.format_exc()}") - with open(result_file, 'wb') as f: - pickle.dump({"error": str(e), "traceback": traceback.format_exc()}, f) - - def _run_moe_lora_comparison(self, model_case: LoRAModelCase, torch_dtype, max_new_tokens=32): - """Run LoRA comparison test by loading saved results from separate sessions.""" - self.logger.info("=== Starting MoE LoRA comparison test ===") - self._log_system_info() - - base_path = model_case.base - self.logger.info(f"Model: {base_path}, dtype: {torch_dtype}") - - # Create temp directory for results - with tempfile.TemporaryDirectory() as temp_dir: - srt_result_file = os.path.join(temp_dir, 'srt_results.pkl') - hf_result_file = os.path.join(temp_dir, 'hf_results.pkl') - - self.logger.info(f"Results directory: {temp_dir}") - self.logger.info(f"SRT result file: {srt_result_file}") - self.logger.info(f"HF result file: {hf_result_file}") - - # Check if results already exist - srt_done = os.path.exists(srt_result_file) - hf_done = 
os.path.exists(hf_result_file) - - self.logger.info(f"SRT results exist: {srt_done}") - self.logger.info(f"HF results exist: {hf_done}") - - if not srt_done: - self.logger.info("Running SGLang test...") - self._run_srt_test(srt_result_file, model_case, torch_dtype, max_new_tokens) - srt_done = True - - if not hf_done: - self.logger.info("Running HuggingFace test...") - self._run_hf_test(hf_result_file, model_case, torch_dtype, max_new_tokens) - hf_done = True - - # Load results - self.logger.info("Loading SRT results...") - if srt_done: - with open(srt_result_file, 'rb') as f: - srt_results = pickle.load(f) - if "error" in srt_results: - error_msg = f"SGLang test failed: {srt_results['error']}" - if "traceback" in srt_results: - error_msg += f"\nTraceback: {srt_results['traceback']}" - self.fail(error_msg) - self.logger.info("SRT results loaded successfully") - - self.logger.info("Loading HF results...") - if hf_done: - with open(hf_result_file, 'rb') as f: - hf_results = pickle.load(f) - if "error" in hf_results: - error_msg = f"HF test failed: {hf_results['error']}" - if "traceback" in hf_results: - error_msg += f"\nTraceback: {hf_results['traceback']}" - self.fail(error_msg) - self.logger.info("HF results loaded successfully") - - # Compare results - if srt_done and hf_done: - self.logger.info("Starting result comparison...") - for config_key in srt_results.keys(): - if config_key in hf_results: - self.logger.info(f"Comparing config: {config_key}") - srt_data = srt_results[config_key] - hf_data = hf_results[config_key] - - self._compare_outputs( - srt_data["srt_outputs"], hf_data["hf_outputs"], - model_case, srt_data["prompts"], srt_data["lora_paths"] - ) - else: - self.logger.warning(f"No HF results for config: {config_key}") - - self.logger.info("Comparison completed successfully") - - def _compare_outputs(self, srt_outputs, hf_outputs, model_case, prompts, lora_paths): - """Compare SGLang and HF outputs.""" - for i, (prompt, lora_path) in 
enumerate(zip(prompts, lora_paths)): - print(f"\nRequest {i}: lora_path='{lora_path}'") - print(f"Prompt: {prompt[:50]}{'...' if len(prompt) > 50 else ''}") - - # Compare output strings - srt_output_str = srt_outputs.output_strs[i].strip() - hf_output_str = hf_outputs.output_strs[i].strip() - - print(f"SRT output: {srt_output_str[:100]}{'...' if len(srt_output_str) > 100 else ''}") - print(f"HF output: {hf_output_str[:100]}{'...' if len(hf_output_str) > 100 else ''}") - - # Calculate ROUGE-L similarity - rouge_l = calculate_rouge_l(srt_output_str, hf_output_str) - print(f"ROUGE-L similarity: {rouge_l:.4f}") - - # Check ROUGE-L tolerance - self.assertGreaterEqual( - rouge_l, - model_case.rouge_l_tolerance, - f"ROUGE-L similarity {rouge_l:.4f} below tolerance {model_case.rouge_l_tolerance} " - f"for request {i} with lora_path='{lora_path}'" - ) - - # Compare logprobs if available - if hasattr(srt_outputs, 'top_input_logprobs') and hasattr(hf_outputs, 'top_input_logprobs'): - if srt_outputs.top_input_logprobs[i] is not None and hf_outputs.top_input_logprobs[i] is not None: - import torch - srt_prefill = torch.tensor(srt_outputs.top_input_logprobs[i]) - hf_prefill = torch.tensor(hf_outputs.top_input_logprobs[i]) - - max_prefill_diff = torch.max(torch.abs(hf_prefill - srt_prefill)) - print(f"Max prefill logprob diff: {max_prefill_diff:.6f}") - - # Check prefill tolerance - prefill_tol = model_case.prefill_tolerance - self.assertLessEqual( - max_prefill_diff, - prefill_tol, - f"Prefill logprob diff {max_prefill_diff:.6f} exceeds tolerance {prefill_tol} " - f"for request {i} with lora_path='{lora_path}'" - ) - - if hasattr(srt_outputs, 'top_output_logprobs') and hasattr(hf_outputs, 'top_output_logprobs'): - if srt_outputs.top_output_logprobs[i] is not None and hf_outputs.top_output_logprobs[i] is not None: - import torch - srt_decode = torch.tensor(srt_outputs.top_output_logprobs[i]) - hf_decode = torch.tensor(hf_outputs.top_output_logprobs[i]) - - max_decode_diff = 
torch.max(torch.abs(hf_decode - srt_decode)) - print(f"Max decode logprob diff: {max_decode_diff:.6f}") - - # Check decode tolerance - decode_tol = model_case.decode_tolerance - self.assertLessEqual( - max_decode_diff, - decode_tol, - f"Decode logprob diff {max_decode_diff:.6f} exceeds tolerance {decode_tol} " - f"for request {i} with lora_path='{lora_path}'" - ) - - def test_moe_lora_qwen15(self): - """Test LoRA on Qwen1.5-MoE-A2.7B.""" - self.logger.info("=== Starting test_moe_lora_qwen15 ===") - - if is_in_ci(): - self.logger.info("Skipping MoE LoRA test in CI environment") - self.skipTest("Skipping MoE LoRA test in CI environment") - - model_case = MOE_LORA_TEST_CASES[0] - self.logger.info(f"Using model case: {model_case.base}") - - # Test with different dtypes - import torch - for torch_dtype in [torch.float16, torch.bfloat16]: - self.logger.info(f"Testing dtype: {torch_dtype}") - with self.subTest(dtype=torch_dtype): - try: - self._run_moe_lora_comparison(model_case, torch_dtype) - except Exception as e: - self.logger.error(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") - import traceback - self.logger.error(f"Traceback: {traceback.format_exc()}") - self.fail(f"Test failed for {model_case.base} with dtype {torch_dtype}: {e}") - - def test_moe_lora_basic_functionality(self): - """Basic functionality test for MoE LoRA dispatch.""" - # This test focuses on the core dispatch logic without full HF comparison - # Useful for debugging the MoE LoRA implementation - - import torch - from sglang.srt.lora.moe_dispatch import moe_dispatch - - # Create test data - num_tokens = 4 - top_k = 2 - num_experts = 8 - - # Mock top-k routing results - topk_ids = torch.tensor([ - [0, 1], # token 0 routes to experts 0, 1 - [2, 3], # token 1 routes to experts 2, 3 - [1, 4], # token 2 routes to experts 1, 4 - [5, 6], # token 3 routes to experts 5, 6 - ], dtype=torch.int32) - - topk_weights = torch.ones_like(topk_ids, dtype=torch.float32) - - # Mock LoRA indices 
(one per token) - lora_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32) # tokens 0,1 use lora 0; tokens 2,3 use lora 1 - - # Run dispatch - token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch( - topk_ids=topk_ids, - topk_weights=topk_weights, - lora_indices=lora_indices, - ) - - # Verify results - # Should have 4 tokens * 2 experts each = 8 dispatched entries - self.assertEqual(len(token_ids), 8) - self.assertEqual(len(expert_ids), 8) - self.assertEqual(len(sorted_topk_weights), 8) - self.assertEqual(len(lora_ids), 8) - - # Check that tokens are grouped by expert (not by LoRA) - # All tokens going to expert 0 should come first, then expert 1, etc. - unique_experts, expert_counts = torch.unique_consecutive(expert_ids, return_counts=True) - self.assertTrue(torch.all(expert_counts >= 1)) # Each expert should have at least one token - - print(f"Dispatch successful: {len(token_ids)} dispatched tokens to experts {unique_experts.tolist()}") - - -def debug_full_comparison(): - """Debug helper to run full SRT vs HF comparison.""" - import torch - import tempfile - import os - - # Set up logging for debugging - logging.basicConfig( - level=logging.INFO, # Less verbose for debug script - format='%(asctime)s - %(levelname)s - %(message)s' - ) - - test = TestMoELoRA() - model_case = MOE_LORA_TEST_CASES[0] - - print("=" * 80) - print("DEBUG: Running Full SRT vs HF Comparison") - print("=" * 80) - - srt_results = {} - hf_results = {} - - # Test configurations - test_configs = [ - {"batch_size": 1, "lora_paths": [model_case.adaptors[0].name]}, # Single request, single LoRA - {"batch_size": 2, "lora_paths": [model_case.adaptors[0].name, model_case.adaptors[0].name]}, # Multiple requests, same LoRA - ] - - try: - # Phase 1: Run SRT tests - print("\n" + "="*50) - print("PHASE 1: Running SGLang (SRT) Tests") - print("="*50) - - server_args = test._get_srt_server_args() - print(f"Server args: quantization={server_args['quantization']}, 
tp_size={server_args['tp_size']}, mem_fraction_static={server_args['mem_fraction_static']}") - - srt_runner = SRTRunner( - model_path=model_case.base, - torch_dtype=torch.float16, - model_type="generation", - tp_size=server_args["tp_size"], - lora_paths=[adaptor.name for adaptor in model_case.adaptors], - max_loras_per_batch=server_args["max_loras_per_batch"], - quantization=server_args["quantization"], - max_lora_chunk_size=server_args["max_lora_chunk_size"], - max_running_requests=server_args["max_running_requests"], - lora_backend=server_args["lora_backend"], - disable_cuda_graph=False, - disable_radix_cache=server_args["disable_radix_cache"], - max_total_tokens=server_args["max_total_tokens"], - page_size=server_args["page_size"], - mem_fraction_static=server_args["mem_fraction_static"], - sleep_on_idle=True, - ) - print("✓ SRT runner created successfully") - - with srt_runner: - print("✓ SRT model loaded successfully") - test._log_system_info() - - for config in test_configs: - batch_size = config["batch_size"] - lora_paths = config["lora_paths"] - prompts = TEST_MOE_PROMPTS[:batch_size] - config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - - print(f"\n--- SRT Testing {config_key} ---") - print(f"Prompts: {prompts}") - - # Ensure reproducibility - ensure_reproducibility() - - srt_outputs = srt_runner.batch_forward( - prompts, - max_new_tokens=32, - lora_paths=lora_paths, - ) - print("✓ SRT batch_forward completed") - - # Print SRT responses - print("=== SRT Generated Responses ===") - for i, (prompt, output) in enumerate(zip(prompts, srt_outputs.output_strs)): - print(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") - print(f"SRT Output {i+1}: {output}") - print("") - - srt_results[config_key] = { - "prompts": prompts, - "lora_paths": lora_paths, - "srt_outputs": srt_outputs - } - - print("✓ All SRT tests completed successfully") - - # Clear GPU memory before HF test - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - print("✓ GPU cache cleared before HF test") - - # Small delay to ensure memory is fully released - import time - time.sleep(2) - test._log_system_info() - - # Phase 2: Run HF tests - print("\n" + "="*50) - print("PHASE 2: Running HuggingFace (HF) Tests") - print("="*50) - - server_args = test._get_srt_server_args() - hf_runner = HFRunner( - model_case.base, - torch_dtype=torch.float16, - model_type="generation", - trust_remote_code=True, # Match SRTRunner behavior - quantization=server_args["quantization"], # Enable quantization if specified - device_map="auto", # Distribute across available GPUs - ) - print("✓ HF runner created successfully") - - with hf_runner: - print("✓ HF model loaded successfully") - test._log_system_info() - - for config in test_configs: - batch_size = config["batch_size"] - lora_paths = config["lora_paths"] - prompts = TEST_MOE_PROMPTS[:batch_size] - config_key = f"batch_{batch_size}_lora_{len(set(lora_paths))}" - - print(f"\n--- HF Testing {config_key} ---") - print(f"Prompts: {prompts}") - - # Ensure reproducibility - ensure_reproducibility() - - hf_outputs = hf_runner.forward( - prompts, - max_new_tokens=32, - lora_paths=lora_paths, - ) - print("✓ HF forward completed") - - # Print HF responses - print("=== HF Generated Responses ===") - for i, (prompt, output) in enumerate(zip(prompts, hf_outputs.output_strs)): - print(f"Prompt {i+1}: {prompt[:100]}{'...' 
if len(prompt) > 100 else ''}") - print(f"HF Output {i+1}: {output}") - print("") - - hf_results[config_key] = { - "prompts": prompts, - "lora_paths": lora_paths, - "hf_outputs": hf_outputs - } - - print("✓ All HF tests completed successfully") - - # Force final GPU memory cleanup - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - print("✓ Final GPU memory cleanup completed") - test._log_system_info() - - # Phase 3: Compare results - print("\n" + "="*50) - print("PHASE 3: Comparing SRT vs HF Outputs") - print("="*50) - - for config_key in srt_results.keys(): - if config_key in hf_results: - srt_data = srt_results[config_key] - hf_data = hf_results[config_key] - - print(f"\n{'='*30} Comparing {config_key} {'='*30}") - try: - test._compare_outputs( - srt_data["srt_outputs"], hf_data["hf_outputs"], - model_case, srt_data["prompts"], srt_data["lora_paths"] - ) - print(f"✓ Comparison passed for {config_key}") - except AssertionError as e: - print(f"✗ Comparison failed for {config_key}: {e}") - raise - else: - print(f"✗ No HF results for config: {config_key}") - - print("\n" + "="*80) - print("🎉 ALL COMPARISONS PASSED! 
SRT and HF outputs match!") - print("="*80) - - except Exception as e: - print(f"\n💥 DEBUG SCRIPT FAILED: {e}") - import traceback - traceback.print_exc() - return False - - return True - - -if __name__ == "__main__": - import sys - - if len(sys.argv) > 1 and sys.argv[1] == "--debug": - # Run debug tests instead of unittest - success = debug_full_comparison() - sys.exit(0 if success else 1) - else: - try: - mp.set_start_method("spawn") - except RuntimeError: - pass - - unittest.main() From 883b5ef224353d7eba90e70810347a96a1040751 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 15 Dec 2025 00:28:08 +0000 Subject: [PATCH 034/150] fix --- python/sglang/srt/lora/mem_pool.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index ea4e56397fb2..c5b085ed910d 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -492,7 +492,8 @@ def load_lora_weight_tensor( expert_match = re.search(r"experts\.(\d+)\.", name) - if expert_match and self.is_moe_module(target_module): + if expert_match: + target_module = target_module + "_moe" # MoE weight - multiple tensors per module (one per expert) if temp_A_buffer[target_module] is None: temp_A_buffer[target_module] = {} @@ -748,4 +749,4 @@ def get_tensor( return buffer_dict[target_module][layer_id] def get_buffer_id(self, lora_uid: str): - return self.uid_to_buffer_id[lora_uid] + return self.uid_to_buffer_id[lora_uid] \ No newline at end of file From b0cd554e41fbff7f9a54dab1d8208f5c741f29ad Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 15 Dec 2025 00:31:44 +0000 Subject: [PATCH 035/150] modify shape initialization to use moe_intermediate_size_ from config file --- python/sglang/srt/lora/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index c6cc51745745..afe01b1bccef 100644 --- 
a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -78,10 +78,14 @@ def get_hidden_dim( head_dim * config.num_attention_heads, config.hidden_size, ) - elif module_name == "gate_up_proj" or module_name == "gate_up_proj_moe": + elif module_name == "gate_up_proj": return config.hidden_size, config.intermediate_size * 2 - elif module_name == "down_proj" or module_name == "down_proj_moe": + elif module_name == "down_proj": return config.intermediate_size, config.hidden_size + elif module_name == "gate_up_proj_moe": + return config.hidden_size, config.moe_intermediate_size + elif module_name == "down_proj_moe": + return config.moe_intermediate_size, config.hidden_size elif module_name == "embed_tokens": # For embedding: input is vocab_size (as embedding lookup), output is hidden_size # if contain extra tokens will be added; otherwise is 0. From d1e31550ff5fa44d01c6df58d81f2afa5c36ad64 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 15 Dec 2025 22:33:52 -0500 Subject: [PATCH 036/150] fix dim mismatch in buffer_view and weights due to stacking --- python/sglang/srt/lora/mem_pool.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index c5b085ed910d..7327fa52742b 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -163,7 +163,12 @@ def get_lora_A_shape( "num_local_experts", getattr(self.base_hf_config, "num_experts", 0), ) - return (self.max_loras_per_batch, num_experts, max_lora_dim, input_dim) + return ( + self.max_loras_per_batch, + num_experts, + max_lora_dim * c, + input_dim, + ) else: return (self.max_loras_per_batch, max_lora_dim * c, input_dim) @@ -588,7 +593,9 @@ def load_lora_weight_tensor( # MoE: multiple tensors per module (one per expert) for expert_id, expert_weight in weights.items(): # Buffer shape: [num_loras, num_experts, max_rank, hidden_dim] - buffer_view = 
target_buffer[buffer_id, expert_id, :lora_rank, :] + buffer_view = target_buffer[ + buffer_id, expert_id, : lora_rank * c, : + ] load_lora_weight_tensor(buffer_view, expert_weight) else: # Standard: single tensor per module @@ -745,8 +752,8 @@ def get_tensor( - 4D tensor [num_loras, num_experts, rank, hidden] for MoE modules """ buffer_dict = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer - + return buffer_dict[target_module][layer_id] def get_buffer_id(self, lora_uid: str): - return self.uid_to_buffer_id[lora_uid] \ No newline at end of file + return self.uid_to_buffer_id[lora_uid] From 1eb9df144a267c6494dc8eb825752c4219fb5bfa Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 15 Dec 2025 22:38:45 -0500 Subject: [PATCH 037/150] fix stacking issue for gate_up_proj for LoRA B --- python/sglang/srt/lora/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index afe01b1bccef..c7ae8c3aebeb 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -83,7 +83,7 @@ def get_hidden_dim( elif module_name == "down_proj": return config.intermediate_size, config.hidden_size elif module_name == "gate_up_proj_moe": - return config.hidden_size, config.moe_intermediate_size + return config.hidden_size, config.moe_intermediate_size * 2 elif module_name == "down_proj_moe": return config.moe_intermediate_size, config.hidden_size elif module_name == "embed_tokens": From b50cfe4587f42e096251cfc540e4b6be4f8632c5 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 26 Dec 2025 17:11:54 -0800 Subject: [PATCH 038/150] add down proj calculation --- python/sglang/srt/lora/layers.py | 53 ++++-- python/sglang/srt/lora/lora_manager.py | 44 +++-- .../lora/triton_ops/per_expert_lora_moe.py | 156 +++++++++--------- 3 files changed, 146 insertions(+), 107 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 
d146233cf69d..5e1c023e2e1e 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -593,18 +593,24 @@ def __init__( ): super().__init__(base_layer, lora_backend) # LoRA tensors will be set by LoRAManager - self.lora_a_weights = None - self.lora_b_weights = None + self.gate_up_lora_a_weights = None + self.gate_up_lora_b_weights = None + self.down_lora_a_weights = None + self.down_lora_b_weights = None def set_lora_info( self, - lora_a_weights: torch.Tensor, - lora_b_weights: torch.Tensor, + gate_up_lora_a_weights: torch.Tensor, + gate_up_lora_b_weights: torch.Tensor, + down_lora_a_weights: torch.Tensor = None, + down_lora_b_weights: torch.Tensor = None, ): """Set LoRA weight tensors from memory pool.""" self.set_lora = True - self.lora_a_weights = lora_a_weights - self.lora_b_weights = lora_b_weights + self.gate_up_lora_a_weights = gate_up_lora_a_weights + self.gate_up_lora_b_weights = gate_up_lora_b_weights + self.down_lora_a_weights = down_lora_a_weights + self.down_lora_b_weights = down_lora_b_weights def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): """ @@ -619,7 +625,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs base_output = self.base_layer.forward(hidden_states, topk_output, **kwargs) # If LoRA is enabled, compute delta and add in-place for memory efficiency - if self.set_lora and self.lora_a_weights is not None: + if self.set_lora and self.gate_up_lora_a_weights is not None: self._compute_lora_delta(hidden_states, topk_output, base_output) return base_output @@ -653,7 +659,6 @@ def _compute_lora_delta( lora_indices = self.lora_backend.forward_batch.token_lora_indices num_experts = self.base_layer.num_experts - num_loras = self.lora_a_weights.shape[0] # Dispatch tokens to experts token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch( @@ -662,22 +667,40 @@ def _compute_lora_delta( lora_indices=lora_indices, ) - - - # Compute per-expert LoRA 
forward (adds to base_output in-place) - per_expert_lora_forward( + # Apply gate_up_proj LoRA: hidden_states -> intermediate space + gate_up_output = per_expert_lora_forward( hidden_states=hidden_states, - lora_a_weights=self.lora_a_weights, - lora_b_weights=self.lora_b_weights, + lora_a_weights=self.gate_up_lora_a_weights, + lora_b_weights=self.gate_up_lora_b_weights, token_ids=token_ids, expert_ids=expert_ids, lora_ids=lora_ids, lora_ranks=lora_ranks, lora_scalings=scalings, num_experts=num_experts, - base_output=base_output, + base_output=None, + is_down_proj=False, ) + # Apply down_proj LoRA: intermediate space -> hidden space, added to base_output + if ( + self.down_lora_a_weights is not None + and self.down_lora_b_weights is not None + ): + per_expert_lora_forward( + hidden_states=gate_up_output, + lora_a_weights=self.down_lora_a_weights, + lora_b_weights=self.down_lora_b_weights, + token_ids=token_ids, + expert_ids=expert_ids, + lora_ids=lora_ids, + lora_ranks=lora_ranks, + lora_scalings=scalings, + num_experts=num_experts, + base_output=base_output, + is_down_proj=True, + ) + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For MoE layers, tensor parallelism is typically not used # Return weights unchanged diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index bc3697752723..1ecebdefd83b 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -288,7 +288,9 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): num_tokens = forward_batch.input_ids.shape[0] # Tokens in current forward pass # Create tensor and fill with adapter indices from segments - token_lora_indices_reordered = torch.empty(num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device) + token_lora_indices_reordered = torch.empty( + num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device + ) seg_indptr = batch_info.seg_indptr # [num_segments + 1] for seg_idx 
in range(batch_info.num_segments): start_token = seg_indptr[seg_idx] @@ -302,7 +304,11 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): else: # Tokens are reordered (chunked backend): need to convert back to original order inverse_permutation = torch.empty_like(batch_info.permutation) - inverse_permutation[batch_info.permutation] = torch.arange(num_tokens, dtype=batch_info.permutation.dtype, device=batch_info.permutation.device) + inverse_permutation[batch_info.permutation] = torch.arange( + num_tokens, + dtype=batch_info.permutation.dtype, + device=batch_info.permutation.device, + ) token_lora_indices = token_lora_indices_reordered[inverse_permutation] forward_batch.token_lora_indices = token_lora_indices @@ -317,21 +323,33 @@ def update_lora_info(self): for layer_id, layer_modules in enumerate(self.lora_modules): for module_name, module in layer_modules.items(): # Hack for FusedMoE layer - if isinstance(module, FusedMoEWithLoRA) and all(x in self.target_modules for x in ['gate_up_proj', 'down_proj']): + if isinstance(module, FusedMoEWithLoRA) and all( + x in self.target_modules for x in ["gate_up_proj", "down_proj"] + ): module.set_lora_info( - self.memory_pool.get_tensor( - target_module='gate_up_proj_moe', + gate_up_lora_a_weights=self.memory_pool.get_tensor( + target_module="gate_up_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_A, + ), + gate_up_lora_b_weights=self.memory_pool.get_tensor( + target_module="gate_up_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_B, + ), + down_lora_a_weights=self.memory_pool.get_tensor( + target_module="down_proj_moe", layer_id=layer_id, lora_type=LoRAType.LORA_A, ), - self.memory_pool.get_tensor( - target_module='down_proj_moe', + down_lora_b_weights=self.memory_pool.get_tensor( + target_module="down_proj_moe", layer_id=layer_id, lora_type=LoRAType.LORA_B, ), ) continue - + target_module = get_target_module_name( module_name, self.memory_pool.target_modules ) @@ -377,8 +395,8 @@ def 
init_state( the target modules and max_lora_rank. """ - assert lora_paths or ( - max_lora_rank is not None and target_modules is not None + assert ( + lora_paths or (max_lora_rank is not None and target_modules is not None) ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." self.init_lora_adapters(lora_paths) @@ -557,9 +575,11 @@ def init_lora_modules(self): module_name, module ) continue - + # Temporarily workaround for FusedMoE layer - if isinstance(module, FusedMoE) and all(x in self.target_modules for x in ['gate_up_proj', 'down_proj']): + if isinstance(module, FusedMoE) and all( + x in self.target_modules for x in ["gate_up_proj", "down_proj"] + ): self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module ) diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index 92553f9ed07a..ab258d7182f1 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -22,41 +22,35 @@ @triton.jit def _per_expert_lora_kernel( # Input/Output pointers - hidden_states_ptr, # [num_total_tokens, hidden_dim] - lora_a_weights_ptr, # [num_loras, num_experts, max_rank, hidden_dim] - lora_b_weights_ptr, # [num_loras, num_experts, intermediate_dim, max_rank] - output_ptr, # [num_total_tokens, intermediate_dim] - + hidden_states_ptr, # [num_total_tokens, input_dim] + lora_a_weights_ptr, # [num_loras, num_experts, max_rank, input_dim] + lora_b_weights_ptr, # [num_loras, num_experts, output_dim, max_rank] + output_ptr, # [num_total_tokens, output_dim] # Dispatch info (length = num_dispatched) - token_ids_ptr, # [num_dispatched] -> index into hidden/output - expert_ids_ptr, # [num_dispatched] - lora_ids_ptr, # [num_dispatched] - + token_ids_ptr, # [num_dispatched] -> index into hidden/output + expert_ids_ptr, # [num_dispatched] + 
lora_ids_ptr, # [num_dispatched] # Dimensions - hidden_dim: tl.constexpr, - intermediate_dim: tl.constexpr, + input_dim: tl.constexpr, + output_dim: tl.constexpr, max_rank: tl.constexpr, num_experts: tl.constexpr, num_dispatched, - - # Strides for 4D LoRA A weights [num_loras, num_experts, max_rank, hidden_dim] + # Strides for 4D LoRA A weights [num_loras, num_experts, max_rank, input_dim] lora_a_stride_lora: tl.constexpr, lora_a_stride_expert: tl.constexpr, lora_a_stride_rank: tl.constexpr, - lora_a_stride_hidden: tl.constexpr, - - # Strides for 4D LoRA B weights [num_loras, num_experts, intermediate_dim, max_rank] + lora_a_stride_input: tl.constexpr, + # Strides for 4D LoRA B weights [num_loras, num_experts, output_dim, max_rank] lora_b_stride_lora: tl.constexpr, lora_b_stride_expert: tl.constexpr, - lora_b_stride_intermediate: tl.constexpr, + lora_b_stride_output: tl.constexpr, lora_b_stride_rank: tl.constexpr, - # LoRA ranks per adapter [num_loras] lora_ranks_ptr, # Scaling factors per adapter [num_loras] lora_scalings_ptr, - - # Block size (used for hidden and output tiling; rank is not tiled) + # Block size (used for input and output tiling; rank is not tiled) BLOCK_SIZE: tl.constexpr, ): """ @@ -66,14 +60,14 @@ def _per_expert_lora_kernel( 3D Grid: (spatial, slices, loras) - spatial = program_id(0): dispatched token index - - slices = program_id(1): tile index along intermediate_dim + - slices = program_id(1): tile index along output_dim - loras = program_id(2): LoRA adapter index """ # 3D grid indices - spatial_id = tl.program_id(0) # dispatched token index - slice_id = tl.program_id(1) # output slice index - lora_id_grid = tl.program_id(2) # LoRA adapter index + spatial_id = tl.program_id(0) # dispatched token index + slice_id = tl.program_id(1) # output slice index + lora_id_grid = tl.program_id(2) # LoRA adapter index # Bounds check on dispatched tokens if spatial_id >= num_dispatched: @@ -99,7 +93,7 @@ def _per_expert_lora_kernel( # Base pointers # 
---------------------------- # hidden_states[actual_token_id, :] - hidden_ptr = hidden_states_ptr + actual_token_id * hidden_dim + hidden_ptr = hidden_states_ptr + actual_token_id * input_dim # A[lora_id_grid, expert_id, :, :] lora_a_base = ( @@ -120,38 +114,38 @@ def _per_expert_lora_kernel( # ---------------------------- # We assume max_rank is small enough to keep as a single 1D vector - r_offs = tl.arange(0, max_rank) # [max_rank] - rank_mask = r_offs < rank # [max_rank] + r_offs = tl.arange(0, max_rank) # [max_rank] + rank_mask = r_offs < rank # [max_rank] # Accumulator for intermediate: [max_rank] intermediate = tl.zeros((max_rank,), dtype=tl.float32) - # Tile over hidden_dim in chunks of BLOCK_SIZE - NUM_HIDDEN_TILES = (hidden_dim + BLOCK_SIZE - 1) // BLOCK_SIZE - for hidden_tile_idx in range(NUM_HIDDEN_TILES): - hidden_start = hidden_tile_idx * BLOCK_SIZE - hidden_offs = hidden_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] - hidden_mask = hidden_offs < hidden_dim # [BLOCK_SIZE] + # Tile over input_dim in chunks of BLOCK_SIZE + NUM_INPUT_TILES = (input_dim + BLOCK_SIZE - 1) // BLOCK_SIZE + for input_tile_idx in range(NUM_INPUT_TILES): + input_start = input_tile_idx * BLOCK_SIZE + input_offs = input_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] + input_mask = input_offs < input_dim # [BLOCK_SIZE] - # Load hidden values for this tile: [BLOCK_SIZE] + # Load input values for this tile: [BLOCK_SIZE] h_vals = tl.load( - hidden_ptr + hidden_offs, - mask=hidden_mask, + hidden_ptr + input_offs, + mask=input_mask, other=0.0, ).to(tl.float32) # Build [max_rank, BLOCK_SIZE] tile of A: # rows: r_offs - # cols: hidden_offs - # offset = base + r * stride_rank + h * stride_hidden + # cols: input_offs + # offset = base + r * stride_rank + h * stride_input a_ptrs = ( lora_a_base + r_offs[:, None] * lora_a_stride_rank - + hidden_offs[None, :] * lora_a_stride_hidden + + input_offs[None, :] * lora_a_stride_input ) a_vals = tl.load( a_ptrs, - mask=rank_mask[:, None] & 
hidden_mask[None, :], + mask=rank_mask[:, None] & input_mask[None, :], other=0.0, ).to(tl.float32) @@ -161,11 +155,11 @@ def _per_expert_lora_kernel( # ---------------------------- # Stage 2: y_slice = B[out_slice, :] @ intermediate - # One output slice per program along intermediate_dim. + # One output slice per program along output_dim. # ---------------------------- out_start = slice_id * BLOCK_SIZE - out_offs = out_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] - out_mask = out_offs < intermediate_dim # [BLOCK_SIZE] + out_offs = out_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] + out_mask = out_offs < output_dim # [BLOCK_SIZE] # If this slice is entirely out of bounds, we can early-exit # (not strictly necessary but cheap) @@ -176,10 +170,10 @@ def _per_expert_lora_kernel( # Build [max_rank, BLOCK_SIZE] tile of B: # rows: r_offs (rank dimension) # cols: out_offs (output dimension) - # offset = base + out * stride_intermediate + r * stride_rank + # offset = base + out * stride_output + r * stride_rank b_ptrs = ( lora_b_base - + out_offs[None, :] * lora_b_stride_intermediate + + out_offs[None, :] * lora_b_stride_output + r_offs[:, None] * lora_b_stride_rank ) b_vals = tl.load( @@ -190,7 +184,7 @@ def _per_expert_lora_kernel( # Contribution: # out_vals[j] = sum_r B[j, r] * intermediate[r] - out_vals = tl.sum(b_vals * intermediate[:, None], axis=0) # [BLOCK_SIZE] + out_vals = tl.sum(b_vals * intermediate[:, None], axis=0) # [BLOCK_SIZE] # Apply scaling out_vals *= scaling @@ -198,7 +192,7 @@ def _per_expert_lora_kernel( # ---------------------------- # Accumulate into global output # ---------------------------- - out_row_base = actual_token_id * intermediate_dim + out_row_base = actual_token_id * output_dim out_ptrs = output_ptr + out_row_base + out_offs tl.atomic_add( @@ -207,6 +201,7 @@ def _per_expert_lora_kernel( mask=out_mask & has_rank, ) + def per_expert_lora_forward( hidden_states: torch.Tensor, lora_a_weights: torch.Tensor, @@ -218,29 +213,33 @@ def 
per_expert_lora_forward( lora_scalings: torch.Tensor, num_experts: int, base_output: torch.Tensor = None, + is_down_proj: bool = False, ) -> torch.Tensor: """ Forward pass for per-expert LoRA computation using a 3D Triton grid: grid = (spatial, slices, loras) Args: - hidden_states: [num_tokens, hidden_dim] - lora_a_weights: [num_loras, num_experts, max_rank, hidden_dim] - lora_b_weights: [num_loras, num_experts, intermediate_dim, max_rank] + hidden_states: [num_tokens, input_dim] where input_dim is hidden_dim for gate_up_proj + or intermediate_dim for down_proj + lora_a_weights: [num_loras, num_experts, max_rank, input_dim] + lora_b_weights: [num_loras, num_experts, output_dim, max_rank] token_ids: [num_dispatched] - Original token indices expert_ids: [num_dispatched] - Expert ID for each dispatched token lora_ids: [num_dispatched] - LoRA ID for each dispatched token lora_ranks: [num_loras] - Rank for each LoRA lora_scalings: [num_loras] - Scaling factor for each LoRA num_experts: Total number of experts - base_output: [num_tokens, intermediate_dim] - Base MoE output (modified in-place) + base_output: [num_tokens, output_dim] - Base MoE output (modified in-place) + is_down_proj: Whether this is for down_proj (intermediate_dim -> hidden_dim) + or gate_up_proj (hidden_dim -> intermediate_dim) Returns: - output: [num_tokens, intermediate_dim] - Base output + LoRA delta (in-place) + output: [num_tokens, output_dim] - Base output + LoRA delta (in-place) """ # Shapes - num_tokens, hidden_dim = hidden_states.shape - num_loras, _, intermediate_dim, max_rank = lora_b_weights.shape + num_tokens, input_dim = hidden_states.shape + num_loras, _, output_dim, max_rank = lora_b_weights.shape num_dispatched = token_ids.shape[0] # Make sure everything is on the same device and contiguous @@ -259,59 +258,56 @@ def per_expert_lora_forward( # Use float32 for accumulation; you can cast back if needed output = torch.zeros( num_tokens, - hidden_dim, + output_dim, dtype=torch.float32, 
device=device, ) else: output = base_output - assert output.shape == (num_tokens, hidden_dim) # TODO (jonahcb): check if this is correct + assert output.shape == ( + num_tokens, + output_dim, + ), f"Expected shape ({num_tokens}, {output_dim}), got {output.shape}" assert output.device == device # Tile size for hidden and output dimensions BLOCK_SIZE = 64 # tune as needed - # Number of output slices along intermediate_dim - num_slices = (intermediate_dim + BLOCK_SIZE - 1) // BLOCK_SIZE + # Number of output slices along output_dim + num_slices = (output_dim + BLOCK_SIZE - 1) // BLOCK_SIZE # 3D grid: (spatial, slices, loras) grid = (num_dispatched, num_slices, num_loras) _per_expert_lora_kernel[grid]( # Pointers - hidden_states, # hidden_states_ptr - lora_a_weights, # lora_a_weights_ptr - lora_b_weights, # lora_b_weights_ptr - output, # output_ptr - + hidden_states, # hidden_states_ptr + lora_a_weights, # lora_a_weights_ptr + lora_b_weights, # lora_b_weights_ptr + output, # output_ptr # Dispatch info - token_ids, # token_ids_ptr - expert_ids, # expert_ids_ptr - lora_ids, # lora_ids_ptr - + token_ids, # token_ids_ptr + expert_ids, # expert_ids_ptr + lora_ids, # lora_ids_ptr # Dimensions - hidden_dim, # hidden_dim - intermediate_dim, # intermediate_dim - max_rank, # max_rank - num_experts, # num_experts - num_dispatched, # num_dispatched (runtime scalar) - - # LoRA A strides: [num_loras, num_experts, max_rank, hidden_dim] + input_dim, # input_dim (hidden_dim for gate_up_proj, intermediate_dim for down_proj) + output_dim, # output_dim (intermediate_dim for gate_up_proj, hidden_dim for down_proj) + max_rank, # max_rank + num_experts, # num_experts + num_dispatched, # num_dispatched (runtime scalar) + # LoRA A strides: [num_loras, num_experts, max_rank, input_dim] lora_a_weights.stride(0), lora_a_weights.stride(1), lora_a_weights.stride(2), lora_a_weights.stride(3), - - # LoRA B strides: [num_loras, num_experts, intermediate_dim, max_rank] + # LoRA B strides: [num_loras, 
num_experts, output_dim, max_rank] lora_b_weights.stride(0), lora_b_weights.stride(1), lora_b_weights.stride(2), lora_b_weights.stride(3), - # Rank & scaling - lora_ranks, # lora_ranks_ptr - lora_scalings, # lora_scalings_ptr - + lora_ranks, # lora_ranks_ptr + lora_scalings, # lora_scalings_ptr # Block size (constexpr) BLOCK_SIZE=BLOCK_SIZE, ) From 28daf591259a6ad714c6b8d05345be8b89774a1c Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 26 Dec 2025 17:41:20 -0800 Subject: [PATCH 039/150] add debugging statements --- python/sglang/srt/lora/layers.py | 20 ++++++ python/sglang/srt/lora/lora_manager.py | 67 +++++++++++++------ .../lora/triton_ops/per_expert_lora_moe.py | 26 +++++++ 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 5e1c023e2e1e..f521672ead19 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -682,6 +682,17 @@ def _compute_lora_delta( is_down_proj=False, ) + # DEBUG: Check for NaNs immediately after gate_up LoRA + if torch.isnan(gate_up_output).any(): + print(f"NaNs detected in gate_up_output! 
Shape: {gate_up_output.shape}") + print( + f"gate_up_output min/max: {gate_up_output.min()}, {gate_up_output.max()}" + ) + print(f"Input hidden_states shape: {hidden_states.shape}") + import pdb + + pdb.set_trace() + # Apply down_proj LoRA: intermediate space -> hidden space, added to base_output if ( self.down_lora_a_weights is not None @@ -701,6 +712,15 @@ def _compute_lora_delta( is_down_proj=True, ) + # DEBUG: Check for NaNs after down_proj LoRA + if torch.isnan(base_output).any(): + print(f"NaNs detected in base_output after down_proj LoRA!") + print(f"base_output min/max: {base_output.min()}, {base_output.max()}") + print(f"gate_up_output shape: {gate_up_output.shape}") + import pdb + + pdb.set_trace() + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For MoE layers, tensor parallelism is typically not used # Return weights unchanged diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 1ecebdefd83b..058a04f665ae 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -326,27 +326,54 @@ def update_lora_info(self): if isinstance(module, FusedMoEWithLoRA) and all( x in self.target_modules for x in ["gate_up_proj", "down_proj"] ): + gate_up_a = self.memory_pool.get_tensor( + target_module="gate_up_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_A, + ) + gate_up_b = self.memory_pool.get_tensor( + target_module="gate_up_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_B, + ) + down_a = self.memory_pool.get_tensor( + target_module="down_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_A, + ) + down_b = self.memory_pool.get_tensor( + target_module="down_proj_moe", + layer_id=layer_id, + lora_type=LoRAType.LORA_B, + ) + + # DEBUG: Check LoRA weights for NaNs + if gate_up_a is not None and torch.isnan(gate_up_a).any(): + print(f"NaNs in gate_up LoRA_A weights!") + import pdb + + pdb.set_trace() + if gate_up_b is not None and 
torch.isnan(gate_up_b).any(): + print(f"NaNs in gate_up LoRA_B weights!") + import pdb + + pdb.set_trace() + if down_a is not None and torch.isnan(down_a).any(): + print(f"NaNs in down LoRA_A weights!") + import pdb + + pdb.set_trace() + if down_b is not None and torch.isnan(down_b).any(): + print(f"NaNs in down LoRA_B weights!") + import pdb + + pdb.set_trace() + module.set_lora_info( - gate_up_lora_a_weights=self.memory_pool.get_tensor( - target_module="gate_up_proj_moe", - layer_id=layer_id, - lora_type=LoRAType.LORA_A, - ), - gate_up_lora_b_weights=self.memory_pool.get_tensor( - target_module="gate_up_proj_moe", - layer_id=layer_id, - lora_type=LoRAType.LORA_B, - ), - down_lora_a_weights=self.memory_pool.get_tensor( - target_module="down_proj_moe", - layer_id=layer_id, - lora_type=LoRAType.LORA_A, - ), - down_lora_b_weights=self.memory_pool.get_tensor( - target_module="down_proj_moe", - layer_id=layer_id, - lora_type=LoRAType.LORA_B, - ), + gate_up_lora_a_weights=gate_up_a, + gate_up_lora_b_weights=gate_up_b, + down_lora_a_weights=down_a, + down_lora_b_weights=down_b, ) continue diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index ab258d7182f1..e0fba6334617 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -237,6 +237,23 @@ def per_expert_lora_forward( Returns: output: [num_tokens, output_dim] - Base output + LoRA delta (in-place) """ + # DEBUG: Check inputs for NaNs + if torch.isnan(hidden_states).any(): + print(f"NaNs detected in hidden_states input! 
Shape: {hidden_states.shape}") + import pdb + + pdb.set_trace() + if lora_a_weights is not None and torch.isnan(lora_a_weights).any(): + print(f"NaNs detected in lora_a_weights!") + import pdb + + pdb.set_trace() + if lora_b_weights is not None and torch.isnan(lora_b_weights).any(): + print(f"NaNs detected in lora_b_weights!") + import pdb + + pdb.set_trace() + # Shapes num_tokens, input_dim = hidden_states.shape num_loras, _, output_dim, max_rank = lora_b_weights.shape @@ -312,4 +329,13 @@ def per_expert_lora_forward( BLOCK_SIZE=BLOCK_SIZE, ) + # DEBUG: Check output for NaNs + if torch.isnan(output).any(): + print(f"NaNs detected in per_expert_lora_forward output!") + print(f"Output shape: {output.shape}, is_down_proj: {is_down_proj}") + print(f"Input shape: {hidden_states.shape}, Output dim: {output_dim}") + import pdb + + pdb.set_trace() + return output From af73406bee331b8a21d586b02a077dab152d5360 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 28 Dec 2025 13:05:16 -0800 Subject: [PATCH 040/150] use intermediate tensors --- python/sglang/srt/lora/layers.py | 44 +++++++++---------- python/sglang/srt/lora/lora_manager.py | 21 --------- .../lora/triton_ops/per_expert_lora_moe.py | 25 ----------- 3 files changed, 20 insertions(+), 70 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index f521672ead19..e4c9678061b7 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -640,6 +640,7 @@ def _compute_lora_delta( Compute LoRA delta using per-expert LoRA weights and add to base_output in-place. Dispatch tokens to experts and compute per-expert deltas. + Uses intermediate caches similar to base MoE implementation for memory efficiency. 
""" from sglang.srt.lora.moe_dispatch import moe_dispatch from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( @@ -659,6 +660,7 @@ def _compute_lora_delta( lora_indices = self.lora_backend.forward_batch.token_lora_indices num_experts = self.base_layer.num_experts + num_tokens, hidden_size = hidden_states.shape # Dispatch tokens to experts token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch( @@ -667,8 +669,22 @@ def _compute_lora_delta( lora_indices=lora_indices, ) + # Get intermediate dimension from LoRA B weights (gate_up output dim) + # gate_up_lora_b_weights shape: [num_loras, num_experts, intermediate_dim, max_rank] + _, _, intermediate_size, _ = self.gate_up_lora_b_weights.shape + + # Allocate intermediate cache for gate_up output (similar to intermediate_cache1 in base MoE) + # This stores the LoRA delta in intermediate space before down projection + num_dispatched = token_ids.shape[0] + lora_intermediate_cache = torch.empty( + (num_tokens, intermediate_size), + dtype=torch.float32, # Use float32 for LoRA accumulation like base implementation + device=hidden_states.device, + ) + # Apply gate_up_proj LoRA: hidden_states -> intermediate space - gate_up_output = per_expert_lora_forward( + # Store result in intermediate cache (no base_output means allocate new tensor) + per_expert_lora_forward( hidden_states=hidden_states, lora_a_weights=self.gate_up_lora_a_weights, lora_b_weights=self.gate_up_lora_b_weights, @@ -678,28 +694,17 @@ def _compute_lora_delta( lora_ranks=lora_ranks, lora_scalings=scalings, num_experts=num_experts, - base_output=None, + base_output=lora_intermediate_cache, # Store in our intermediate cache is_down_proj=False, ) - # DEBUG: Check for NaNs immediately after gate_up LoRA - if torch.isnan(gate_up_output).any(): - print(f"NaNs detected in gate_up_output! 
Shape: {gate_up_output.shape}") - print( - f"gate_up_output min/max: {gate_up_output.min()}, {gate_up_output.max()}" - ) - print(f"Input hidden_states shape: {hidden_states.shape}") - import pdb - - pdb.set_trace() - # Apply down_proj LoRA: intermediate space -> hidden space, added to base_output if ( self.down_lora_a_weights is not None and self.down_lora_b_weights is not None ): per_expert_lora_forward( - hidden_states=gate_up_output, + hidden_states=lora_intermediate_cache, # Use intermediate cache as input lora_a_weights=self.down_lora_a_weights, lora_b_weights=self.down_lora_b_weights, token_ids=token_ids, @@ -708,19 +713,10 @@ def _compute_lora_delta( lora_ranks=lora_ranks, lora_scalings=scalings, num_experts=num_experts, - base_output=base_output, + base_output=base_output, # Add directly to base_output in-place is_down_proj=True, ) - # DEBUG: Check for NaNs after down_proj LoRA - if torch.isnan(base_output).any(): - print(f"NaNs detected in base_output after down_proj LoRA!") - print(f"base_output min/max: {base_output.min()}, {base_output.max()}") - print(f"gate_up_output shape: {gate_up_output.shape}") - import pdb - - pdb.set_trace() - def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For MoE layers, tensor parallelism is typically not used # Return weights unchanged diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 058a04f665ae..022740a4376d 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -347,27 +347,6 @@ def update_lora_info(self): lora_type=LoRAType.LORA_B, ) - # DEBUG: Check LoRA weights for NaNs - if gate_up_a is not None and torch.isnan(gate_up_a).any(): - print(f"NaNs in gate_up LoRA_A weights!") - import pdb - - pdb.set_trace() - if gate_up_b is not None and torch.isnan(gate_up_b).any(): - print(f"NaNs in gate_up LoRA_B weights!") - import pdb - - pdb.set_trace() - if down_a is not None and torch.isnan(down_a).any(): - print(f"NaNs 
in down LoRA_A weights!") - import pdb - - pdb.set_trace() - if down_b is not None and torch.isnan(down_b).any(): - print(f"NaNs in down LoRA_B weights!") - import pdb - - pdb.set_trace() module.set_lora_info( gate_up_lora_a_weights=gate_up_a, diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index e0fba6334617..765fa8ea796b 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -237,22 +237,6 @@ def per_expert_lora_forward( Returns: output: [num_tokens, output_dim] - Base output + LoRA delta (in-place) """ - # DEBUG: Check inputs for NaNs - if torch.isnan(hidden_states).any(): - print(f"NaNs detected in hidden_states input! Shape: {hidden_states.shape}") - import pdb - - pdb.set_trace() - if lora_a_weights is not None and torch.isnan(lora_a_weights).any(): - print(f"NaNs detected in lora_a_weights!") - import pdb - - pdb.set_trace() - if lora_b_weights is not None and torch.isnan(lora_b_weights).any(): - print(f"NaNs detected in lora_b_weights!") - import pdb - - pdb.set_trace() # Shapes num_tokens, input_dim = hidden_states.shape @@ -329,13 +313,4 @@ def per_expert_lora_forward( BLOCK_SIZE=BLOCK_SIZE, ) - # DEBUG: Check output for NaNs - if torch.isnan(output).any(): - print(f"NaNs detected in per_expert_lora_forward output!") - print(f"Output shape: {output.shape}, is_down_proj: {is_down_proj}") - print(f"Input shape: {hidden_states.shape}, Output dim: {output_dim}") - import pdb - - pdb.set_trace() - return output From f5a22efff6ee6b32e64f8a6ebb86f9a58efa8fb3 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 28 Dec 2025 16:07:40 -0800 Subject: [PATCH 041/150] return LoRA addition as well --- python/sglang/srt/lora/layers.py | 10 +++--- .../lora/triton_ops/per_expert_lora_moe.py | 33 +++++++++++++++---- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/lora/layers.py 
b/python/sglang/srt/lora/layers.py index e4c9678061b7..dcad47998914 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -621,12 +621,15 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs 2. Parallel LoRA delta computation (if enabled, added in-place) 3. Return modified base_output """ + # Copy hidden_states for LoRA computation to ensure we use unmodified input + hidden_states_for_lora = hidden_states.clone() + # Run base MoE base_output = self.base_layer.forward(hidden_states, topk_output, **kwargs) # If LoRA is enabled, compute delta and add in-place for memory efficiency if self.set_lora and self.gate_up_lora_a_weights is not None: - self._compute_lora_delta(hidden_states, topk_output, base_output) + self._compute_lora_delta(hidden_states_for_lora, topk_output, base_output) return base_output @@ -675,7 +678,6 @@ def _compute_lora_delta( # Allocate intermediate cache for gate_up output (similar to intermediate_cache1 in base MoE) # This stores the LoRA delta in intermediate space before down projection - num_dispatched = token_ids.shape[0] lora_intermediate_cache = torch.empty( (num_tokens, intermediate_size), dtype=torch.float32, # Use float32 for LoRA accumulation like base implementation @@ -684,7 +686,7 @@ def _compute_lora_delta( # Apply gate_up_proj LoRA: hidden_states -> intermediate space # Store result in intermediate cache (no base_output means allocate new tensor) - per_expert_lora_forward( + _, _ = per_expert_lora_forward( hidden_states=hidden_states, lora_a_weights=self.gate_up_lora_a_weights, lora_b_weights=self.gate_up_lora_b_weights, @@ -703,7 +705,7 @@ def _compute_lora_delta( self.down_lora_a_weights is not None and self.down_lora_b_weights is not None ): - per_expert_lora_forward( + _, _ = per_expert_lora_forward( hidden_states=lora_intermediate_cache, # Use intermediate cache as input lora_a_weights=self.down_lora_a_weights, lora_b_weights=self.down_lora_b_weights, diff --git 
a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index 765fa8ea796b..e469ffb6b249 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -25,7 +25,8 @@ def _per_expert_lora_kernel( hidden_states_ptr, # [num_total_tokens, input_dim] lora_a_weights_ptr, # [num_loras, num_experts, max_rank, input_dim] lora_b_weights_ptr, # [num_loras, num_experts, output_dim, max_rank] - output_ptr, # [num_total_tokens, output_dim] + output_ptr, # [num_total_tokens, output_dim] - base output (modified in-place) + lora_output_ptr, # [num_total_tokens, output_dim] - separate LoRA-only output # Dispatch info (length = num_dispatched) token_ids_ptr, # [num_dispatched] -> index into hidden/output expert_ids_ptr, # [num_dispatched] @@ -190,13 +191,22 @@ def _per_expert_lora_kernel( out_vals *= scaling # ---------------------------- - # Accumulate into global output + # Accumulate into global output (base_output) and store to lora_output # ---------------------------- out_row_base = actual_token_id * output_dim out_ptrs = output_ptr + out_row_base + out_offs + lora_out_ptrs = lora_output_ptr + out_row_base + out_offs + # Add to base_output in-place tl.atomic_add( out_ptrs, + out_vals.to(tl.float16), + mask=out_mask & has_rank, + ) + + # Also store to separate lora_output tensor + tl.atomic_add( + lora_out_ptrs, out_vals, mask=out_mask & has_rank, ) @@ -214,7 +224,7 @@ def per_expert_lora_forward( num_experts: int, base_output: torch.Tensor = None, is_down_proj: bool = False, -) -> torch.Tensor: +) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for per-expert LoRA computation using a 3D Triton grid: grid = (spatial, slices, loras) @@ -235,7 +245,9 @@ def per_expert_lora_forward( or gate_up_proj (hidden_dim -> intermediate_dim) Returns: - output: [num_tokens, output_dim] - Base output + LoRA delta (in-place) + tuple of: + output: 
[num_tokens, output_dim] - Base output + LoRA delta (in-place) + lora_output: [num_tokens, output_dim] - Just the LoRA delta contribution """ # Shapes @@ -271,6 +283,14 @@ def per_expert_lora_forward( ), f"Expected shape ({num_tokens}, {output_dim}), got {output.shape}" assert output.device == device + # Allocate separate tensor for just the LoRA contribution + lora_output = torch.zeros( + num_tokens, + output_dim, + dtype=torch.float32, + device=device, + ) + # Tile size for hidden and output dimensions BLOCK_SIZE = 64 # tune as needed @@ -285,7 +305,8 @@ def per_expert_lora_forward( hidden_states, # hidden_states_ptr lora_a_weights, # lora_a_weights_ptr lora_b_weights, # lora_b_weights_ptr - output, # output_ptr + output, # output_ptr (base output, modified in-place) + lora_output, # lora_output_ptr (separate LoRA-only output) # Dispatch info token_ids, # token_ids_ptr expert_ids, # expert_ids_ptr @@ -313,4 +334,4 @@ def per_expert_lora_forward( BLOCK_SIZE=BLOCK_SIZE, ) - return output + return output, lora_output From 18476d7d583bfaae4334a9ada4f53071efad68c3 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 28 Dec 2025 16:33:43 -0800 Subject: [PATCH 042/150] fix atomic add issue --- .../srt/lora/triton_ops/per_expert_lora_moe.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index e469ffb6b249..9fa4ffaa8fb9 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -197,19 +197,14 @@ def _per_expert_lora_kernel( out_ptrs = output_ptr + out_row_base + out_offs lora_out_ptrs = lora_output_ptr + out_row_base + out_offs + # Compute combined mask + store_mask = out_mask & has_rank + # Add to base_output in-place - tl.atomic_add( - out_ptrs, - out_vals.to(tl.float16), - mask=out_mask & has_rank, - ) + tl.atomic_add(out_ptrs, 
out_vals.to(tl.float16), store_mask) # Also store to separate lora_output tensor - tl.atomic_add( - lora_out_ptrs, - out_vals, - mask=out_mask & has_rank, - ) + tl.atomic_add(lora_out_ptrs, out_vals, store_mask) def per_expert_lora_forward( From 0726a93dc240aeff58a1fbc686a7ea1fff2d1cd3 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 29 Dec 2025 07:56:52 -0800 Subject: [PATCH 043/150] Make sure all tensor types match --- python/sglang/srt/lora/layers.py | 2 +- python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index dcad47998914..342ff1387cd1 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -680,7 +680,7 @@ def _compute_lora_delta( # This stores the LoRA delta in intermediate space before down projection lora_intermediate_cache = torch.empty( (num_tokens, intermediate_size), - dtype=torch.float32, # Use float32 for LoRA accumulation like base implementation + dtype=hidden_states.dtype, # Use consistent dtype with model device=hidden_states.device, ) diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index 9fa4ffaa8fb9..a23a932e169a 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -252,6 +252,9 @@ def per_expert_lora_forward( # Make sure everything is on the same device and contiguous device = hidden_states.device + + # Use hidden_states dtype for consistency with model + dtype = hidden_states.dtype hidden_states = hidden_states.contiguous() lora_a_weights = lora_a_weights.contiguous() lora_b_weights = lora_b_weights.contiguous() @@ -263,11 +266,11 @@ def per_expert_lora_forward( # Initialize or reuse output tensor for in-place addition if base_output is None: - # Use float32 for accumulation; you can cast back 
if needed + # Use specified dtype for consistency with model output = torch.zeros( num_tokens, output_dim, - dtype=torch.float32, + dtype=dtype, device=device, ) else: @@ -282,7 +285,7 @@ def per_expert_lora_forward( lora_output = torch.zeros( num_tokens, output_dim, - dtype=torch.float32, + dtype=dtype, device=device, ) From b211fb034bf504be029979d92ce2ccab353170b5 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 29 Dec 2025 10:34:11 -0800 Subject: [PATCH 044/150] make sure types match --- python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index a23a932e169a..ba9549d767fc 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -200,11 +200,14 @@ def _per_expert_lora_kernel( # Compute combined mask store_mask = out_mask & has_rank + # Convert to output dtype (matches hidden_states dtype, could be float16/bfloat16/float32) + out_vals_typed = out_vals.to(output_ptr.dtype.element_ty) + # Add to base_output in-place - tl.atomic_add(out_ptrs, out_vals.to(tl.float16), store_mask) + tl.atomic_add(out_ptrs, out_vals_typed, store_mask) - # Also store to separate lora_output tensor - tl.atomic_add(lora_out_ptrs, out_vals, store_mask) + # Also store to separate lora_output tensor (same dtype) + tl.atomic_add(lora_out_ptrs, out_vals_typed, store_mask) def per_expert_lora_forward( From cef94603a12972d7b4ce87adebf0ee697859230c Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 29 Dec 2025 10:50:28 -0800 Subject: [PATCH 045/150] Add topk weights multiplications --- python/sglang/srt/lora/layers.py | 4 ++++ .../lora/triton_ops/per_expert_lora_moe.py | 22 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 
342ff1387cd1..cd9472600fec 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -686,6 +686,7 @@ def _compute_lora_delta( # Apply gate_up_proj LoRA: hidden_states -> intermediate space # Store result in intermediate cache (no base_output means allocate new tensor) + # Note: topk_weights are NOT applied here - they are applied on the final down_proj output _, _ = per_expert_lora_forward( hidden_states=hidden_states, lora_a_weights=self.gate_up_lora_a_weights, @@ -698,9 +699,11 @@ def _compute_lora_delta( num_experts=num_experts, base_output=lora_intermediate_cache, # Store in our intermediate cache is_down_proj=False, + topk_weights=None, # No router weight multiplication for gate_up ) # Apply down_proj LoRA: intermediate space -> hidden space, added to base_output + # Router weights (topk_weights) are applied here to scale each expert's contribution if ( self.down_lora_a_weights is not None and self.down_lora_b_weights is not None @@ -717,6 +720,7 @@ def _compute_lora_delta( num_experts=num_experts, base_output=base_output, # Add directly to base_output in-place is_down_proj=True, + topk_weights=sorted_topk_weights, # Apply router weights to final output ) def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index ba9549d767fc..9e5d083217cb 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -31,6 +31,7 @@ def _per_expert_lora_kernel( token_ids_ptr, # [num_dispatched] -> index into hidden/output expert_ids_ptr, # [num_dispatched] lora_ids_ptr, # [num_dispatched] + topk_weights_ptr, # [num_dispatched] - Router weights for each dispatched token # Dimensions input_dim: tl.constexpr, output_dim: tl.constexpr, @@ -53,6 +54,8 @@ def _per_expert_lora_kernel( lora_scalings_ptr, # Block size (used for input and output 
tiling; rank is not tiled) BLOCK_SIZE: tl.constexpr, + # Whether to multiply by router weights + MUL_ROUTED_WEIGHT: tl.constexpr, ): """ Compute per-expert LoRA delta: @@ -190,6 +193,11 @@ def _per_expert_lora_kernel( # Apply scaling out_vals *= scaling + # Apply router weight if enabled (matches base MoE behavior) + if MUL_ROUTED_WEIGHT: + topk_weight = tl.load(topk_weights_ptr + spatial_id) + out_vals *= topk_weight + # ---------------------------- # Accumulate into global output (base_output) and store to lora_output # ---------------------------- @@ -222,6 +230,7 @@ def per_expert_lora_forward( num_experts: int, base_output: torch.Tensor = None, is_down_proj: bool = False, + topk_weights: torch.Tensor = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for per-expert LoRA computation using a 3D Triton grid: @@ -241,6 +250,8 @@ def per_expert_lora_forward( base_output: [num_tokens, output_dim] - Base MoE output (modified in-place) is_down_proj: Whether this is for down_proj (intermediate_dim -> hidden_dim) or gate_up_proj (hidden_dim -> intermediate_dim) + topk_weights: [num_dispatched] - Router weights for each dispatched token. + Always multiplied by output (router weights are applied to final output). 
Returns: tuple of: @@ -267,6 +278,14 @@ def per_expert_lora_forward( lora_ranks = lora_ranks.contiguous() lora_scalings = lora_scalings.contiguous() + # Handle topk_weights (always provided, but only applied for down_proj) + mul_routed_weight = is_down_proj # Apply router weights only to down_proj output + if topk_weights is not None: + topk_weights = topk_weights.contiguous() + else: + # Create dummy tensor if not provided + topk_weights = torch.empty(0, device=device, dtype=dtype) + # Initialize or reuse output tensor for in-place addition if base_output is None: # Use specified dtype for consistency with model @@ -312,6 +331,7 @@ def per_expert_lora_forward( token_ids, # token_ids_ptr expert_ids, # expert_ids_ptr lora_ids, # lora_ids_ptr + topk_weights, # topk_weights_ptr # Dimensions input_dim, # input_dim (hidden_dim for gate_up_proj, intermediate_dim for down_proj) output_dim, # output_dim (intermediate_dim for gate_up_proj, hidden_dim for down_proj) @@ -333,6 +353,8 @@ def per_expert_lora_forward( lora_scalings, # lora_scalings_ptr # Block size (constexpr) BLOCK_SIZE=BLOCK_SIZE, + # Router weight multiplication flag + MUL_ROUTED_WEIGHT=mul_routed_weight, ) return output, lora_output From 4e688e303d34c61b3c113d972803a9e9e1edb4af Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 29 Dec 2025 17:45:40 -0800 Subject: [PATCH 046/150] fix max_rank issues --- python/sglang/srt/lora/mem_pool.py | 38 +++++++++++++------ .../lora/triton_ops/per_expert_lora_moe.py | 35 ++++++++++++----- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 7327fa52742b..f143a8361efe 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -125,7 +125,13 @@ def is_moe_module(self, module_name: str) -> bool: """Check if module is part of MoE experts.""" return "moe" in module_name - def _get_standard_shape(self, module_name: str, base_model: torch.nn.Module, 
max_lora_dim: int, layer_idx: int) -> Tuple[int]: + def _get_standard_shape( + self, + module_name: str, + base_model: torch.nn.Module, + max_lora_dim: int, + layer_idx: int, + ) -> Tuple[int]: """Get 3D shape for standard (non-MoE) modules.""" input_dim, _ = get_hidden_dim( module_name, self.base_hf_config, base_model, layer_idx @@ -163,10 +169,15 @@ def get_lora_A_shape( "num_local_experts", getattr(self.base_hf_config, "num_experts", 0), ) + # Allocate all MoE buffers with the same maximum rank dimension + # to ensure consistent kernel compilation. The maximum stacking factor is 2. + max_rank_dim = ( + max_lora_dim * 2 + ) # Accommodate maximum stacking (gate_up_proj) return ( self.max_loras_per_batch, num_experts, - max_lora_dim * c, + max_rank_dim, input_dim, ) else: @@ -246,11 +257,13 @@ def init_buffer( get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]], ): # Check if model has both shared experts and MoE experts - has_shared_experts = hasattr(base_model.config, 'shared_expert_intermediate_size') and \ - base_model.config.shared_expert_intermediate_size > 0 + has_shared_experts = ( + hasattr(base_model.config, "shared_expert_intermediate_size") + and base_model.config.shared_expert_intermediate_size > 0 + ) has_moe = getattr(base_model.config, "num_experts", 1) > 1 - # Shape functions automatically handle both 3D (standard) and 4D (MoE) + # Shape functions automatically handle both 3D (standard) and 4D (MoE) target_modules = target_modules - set(EMBEDDING_NAMES) for module_name in target_modules: # Special handling for ambiguous target modules that can be in different contexts @@ -261,7 +274,9 @@ def init_buffer( shared_key = module_name buffer[shared_key] = [ torch.empty( - get_lora_shape_fn(module_name, base_model, self.max_lora_rank, idx), + get_lora_shape_fn( + module_name, base_model, self.max_lora_rank, idx + ), dtype=self.dtype, device=device, ) @@ -272,7 +287,9 @@ def init_buffer( moe_key = f"{module_name}_moe" buffer[moe_key] 
= [ torch.empty( - get_lora_shape_fn(moe_key, base_model, self.max_lora_rank, idx), + get_lora_shape_fn( + moe_key, base_model, self.max_lora_rank, idx + ), dtype=self.dtype, device=device, ) @@ -522,16 +539,16 @@ def load_lora_weight_tensor( # TODO (Jonahcb): check if the code can be refactored to avoid the special handling for FusedMoEWithLoRA # Handle FusedMoEWithLoRA specially - it contains multiple target modules from sglang.srt.lora.layers import FusedMoEWithLoRA + if isinstance(module, FusedMoEWithLoRA): # FusedMoEWithLoRA contains both gate_up_proj and down_proj - moe_target_modules = ['gate_up_proj_moe', 'down_proj_moe'] + moe_target_modules = ["gate_up_proj_moe", "down_proj_moe"] for target_module in moe_target_modules: - if temp_A_buffer[target_module] is None: # Skip weight slicing if the weight is not present in the adapter continue - # Handle MoE modules (they contain dicts of per-expert tensors) + # Handle MoE modules (they contain dicts of per-expert tensors) # Slice each expert's weights individually for expert_id in temp_A_buffer[target_module].keys(): temp_A_buffer[target_module][expert_id] = ( @@ -618,7 +635,6 @@ def load_lora_weight_tensor( load_lora_weight_tensor(buffer_view, weights) if lora_adapter.embedding_layers: - org_vocab_size = self.base_hf_config.vocab_size lora_added_tokens_size = lora_adapter.config.lora_added_tokens_size # Only when LoRA is applied to the embedding layer will it have the extra-token issue that needs to be resolved. 
diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index 9e5d083217cb..bb18bf9f6e54 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -56,6 +56,8 @@ def _per_expert_lora_kernel( BLOCK_SIZE: tl.constexpr, # Whether to multiply by router weights MUL_ROUTED_WEIGHT: tl.constexpr, + # Whether this is down_proj (affects stacking factor for rank calculation) + IS_DOWN_PROJ: tl.constexpr, ): """ Compute per-expert LoRA delta: @@ -89,7 +91,13 @@ def _per_expert_lora_kernel( # Load LoRA rank and scaling (scalar tensors) for this LoRA adapter rank = tl.load(lora_ranks_ptr + lora_id_grid) scaling = tl.load(lora_scalings_ptr + lora_id_grid) - has_rank = rank > 0 + + # Adjust rank for stacked modules (gate_up_proj has stacking factor 2) + effective_rank = rank + if not IS_DOWN_PROJ: # gate_up_proj case + effective_rank = rank * 2 + + has_rank = effective_rank > 0 if not has_rank: return @@ -119,7 +127,7 @@ def _per_expert_lora_kernel( # We assume max_rank is small enough to keep as a single 1D vector r_offs = tl.arange(0, max_rank) # [max_rank] - rank_mask = r_offs < rank # [max_rank] + rank_mask = r_offs < effective_rank # [max_rank] # Accumulator for intermediate: [max_rank] intermediate = tl.zeros((max_rank,), dtype=tl.float32) @@ -145,7 +153,8 @@ def _per_expert_lora_kernel( a_ptrs = ( lora_a_base + r_offs[:, None] * lora_a_stride_rank - + input_offs[None, :] * lora_a_stride_input + + input_offs[None, :] + * lora_a_stride_input # check if it is necessary to multiply by stride value as it should be contigious in this dimension ) a_vals = tl.load( a_ptrs, @@ -195,7 +204,9 @@ def _per_expert_lora_kernel( # Apply router weight if enabled (matches base MoE behavior) if MUL_ROUTED_WEIGHT: - topk_weight = tl.load(topk_weights_ptr + spatial_id) + topk_weight = tl.load( + topk_weights_ptr + spatial_id + ) # I don't think 
this correctly understands how top_k weights is organized (now moe dispatch reorganizes topk weights to work w this) out_vals *= topk_weight # ---------------------------- @@ -205,17 +216,14 @@ def _per_expert_lora_kernel( out_ptrs = output_ptr + out_row_base + out_offs lora_out_ptrs = lora_output_ptr + out_row_base + out_offs - # Compute combined mask - store_mask = out_mask & has_rank - # Convert to output dtype (matches hidden_states dtype, could be float16/bfloat16/float32) out_vals_typed = out_vals.to(output_ptr.dtype.element_ty) # Add to base_output in-place - tl.atomic_add(out_ptrs, out_vals_typed, store_mask) + tl.atomic_add(out_ptrs, out_vals_typed, out_mask) # Also store to separate lora_output tensor (same dtype) - tl.atomic_add(lora_out_ptrs, out_vals_typed, store_mask) + tl.atomic_add(lora_out_ptrs, out_vals_typed, out_mask) def per_expert_lora_forward( @@ -261,9 +269,14 @@ def per_expert_lora_forward( # Shapes num_tokens, input_dim = hidden_states.shape - num_loras, _, output_dim, max_rank = lora_b_weights.shape + num_loras, _, output_dim, _ = lora_b_weights.shape num_dispatched = token_ids.shape[0] + # Use fixed max_rank for consistent kernel compilation + # Maximum stacking factor is 2 (for gate_up_proj), so max_rank = max_lora_rank * 2 + # We assume max_lora_rank is reasonably small (e.g., 64-128) so max_rank = 256 is safe + max_rank = 256 # Conservative upper bound for max_lora_rank * 2 + # Make sure everything is on the same device and contiguous device = hidden_states.device @@ -355,6 +368,8 @@ def per_expert_lora_forward( BLOCK_SIZE=BLOCK_SIZE, # Router weight multiplication flag MUL_ROUTED_WEIGHT=mul_routed_weight, + # Whether this is down_proj + IS_DOWN_PROJ=is_down_proj, ) return output, lora_output From c616f58d72e4cf91b462ae23af8e1740159575aa Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 29 Dec 2025 18:15:33 -0800 Subject: [PATCH 047/150] clean up debugging code --- python/sglang/srt/lora/layers.py | 5 +---- 
python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py | 8 +------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index cd9472600fec..ca04fa8ed2c7 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -621,15 +621,12 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs 2. Parallel LoRA delta computation (if enabled, added in-place) 3. Return modified base_output """ - # Copy hidden_states for LoRA computation to ensure we use unmodified input - hidden_states_for_lora = hidden_states.clone() - # Run base MoE base_output = self.base_layer.forward(hidden_states, topk_output, **kwargs) # If LoRA is enabled, compute delta and add in-place for memory efficiency if self.set_lora and self.gate_up_lora_a_weights is not None: - self._compute_lora_delta(hidden_states_for_lora, topk_output, base_output) + self._compute_lora_delta(hidden_states, topk_output, base_output) return base_output diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index bb18bf9f6e54..93bcec5e3ecf 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -129,6 +129,7 @@ def _per_expert_lora_kernel( r_offs = tl.arange(0, max_rank) # [max_rank] rank_mask = r_offs < effective_rank # [max_rank] + # TODO (Jonahcb): check if it is better to allocate outside the kernel # Accumulator for intermediate: [max_rank] intermediate = tl.zeros((max_rank,), dtype=tl.float32) @@ -174,12 +175,6 @@ def _per_expert_lora_kernel( out_offs = out_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] out_mask = out_offs < output_dim # [BLOCK_SIZE] - # If this slice is entirely out of bounds, we can early-exit - # (not strictly necessary but cheap) - # NOTE: Triton doesn't have a direct "if not any(mask)" primitive, - # but the mask 
will naturally guard loads/stores below, so this is safe to omit. - # We'll just rely on masks. - # Build [max_rank, BLOCK_SIZE] tile of B: # rows: r_offs (rank dimension) # cols: out_offs (output dimension) @@ -195,7 +190,6 @@ def _per_expert_lora_kernel( other=0.0, ).to(tl.float32) - # Contribution: # out_vals[j] = sum_r B[j, r] * intermediate[r] out_vals = tl.sum(b_vals * intermediate[:, None], axis=0) # [BLOCK_SIZE] From ac1ffead9720a95fa55cc959bbfdadf087c079f6 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 30 Dec 2025 16:05:06 -0800 Subject: [PATCH 048/150] use torch.zeros --- python/sglang/srt/lora/layers.py | 2 +- python/sglang/srt/lora/lora_manager.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index ca04fa8ed2c7..2f0bfbc94e88 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -675,7 +675,7 @@ def _compute_lora_delta( # Allocate intermediate cache for gate_up output (similar to intermediate_cache1 in base MoE) # This stores the LoRA delta in intermediate space before down projection - lora_intermediate_cache = torch.empty( + lora_intermediate_cache = torch.zeros( (num_tokens, intermediate_size), dtype=hidden_states.dtype, # Use consistent dtype with model device=hidden_states.device, diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 022740a4376d..e04e084912a8 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -347,7 +347,6 @@ def update_lora_info(self): lora_type=LoRAType.LORA_B, ) - module.set_lora_info( gate_up_lora_a_weights=gate_up_a, gate_up_lora_b_weights=gate_up_b, From 39c9316b4d22a7343327190bd1d7d757452121dc Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 30 Dec 2025 16:34:50 -0800 Subject: [PATCH 049/150] fix --- test/registered/lora/test_lora_hf_sgl_logprob_diff.py | 4 ---- 1 file changed, 4 deletions(-) 
diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index d51f3763e628..ebb4b784798d 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -536,8 +536,6 @@ def test_moe_lora_logprob_comparison_basic(self): """ Test comparing HF and SGLang MoE LoRA logprobs with basic prompts. """ - if is_in_ci(): - self.skipTest("Skipping in CI environment - requires large MoE models") model_path = "Qwen/Qwen1.5-MoE-A2.7B" lora_paths = ["sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest"] @@ -554,8 +552,6 @@ def test_moe_lora_logprob_comparison_full(self): """ Full test comparing HF and SGLang MoE LoRA logprobs with all default prompts. """ - if is_in_ci(): - self.skipTest("Skipping in CI environment - requires large MoE models") model_path = "Qwen/Qwen1.5-MoE-A2.7B" lora_paths = ["sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest"] From dfe69e9574cc697ff07a290066d3d09f6daa2531 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 30 Dec 2025 16:38:13 -0800 Subject: [PATCH 050/150] fix --- python/sglang/srt/lora/layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 2f0bfbc94e88..957e921fec2f 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -621,12 +621,13 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs 2. Parallel LoRA delta computation (if enabled, added in-place) 3. 
Return modified base_output """ + hidden_states_for_lora = hidden_states.clone() # Run base MoE base_output = self.base_layer.forward(hidden_states, topk_output, **kwargs) # If LoRA is enabled, compute delta and add in-place for memory efficiency if self.set_lora and self.gate_up_lora_a_weights is not None: - self._compute_lora_delta(hidden_states, topk_output, base_output) + self._compute_lora_delta(hidden_states_for_lora, topk_output, base_output) return base_output From 8c41ff9dd76a147b09e773aedb14e8fa153c896a Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 30 Dec 2025 17:58:01 -0800 Subject: [PATCH 051/150] add comments for clarity --- python/sglang/srt/lora/layers.py | 1 + python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 957e921fec2f..5149e7b6e247 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -685,6 +685,7 @@ def _compute_lora_delta( # Apply gate_up_proj LoRA: hidden_states -> intermediate space # Store result in intermediate cache (no base_output means allocate new tensor) # Note: topk_weights are NOT applied here - they are applied on the final down_proj output + # TODO (Jonahcb): remove return values after done debugging _, _ = per_expert_lora_forward( hidden_states=hidden_states, lora_a_weights=self.gate_up_lora_a_weights, diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index 93bcec5e3ecf..8dccb460949b 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -216,6 +216,7 @@ def _per_expert_lora_kernel( # Add to base_output in-place tl.atomic_add(out_ptrs, out_vals_typed, out_mask) + # TODO (Jonahcb): remove unnecessary store to lora_output tensor after done debugging # Also store to separate lora_output tensor (same 
dtype) tl.atomic_add(lora_out_ptrs, out_vals_typed, out_mask) From 5b3f5aaafabd73e0744ccd0610e0fd5ca3df7712 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 31 Dec 2025 16:46:46 -0800 Subject: [PATCH 052/150] add activation function --- python/sglang/srt/lora/layers.py | 113 ++++++++++++++---- .../lora/triton_ops/per_expert_lora_moe.py | 69 +++++++---- 2 files changed, 138 insertions(+), 44 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 5149e7b6e247..f01167ca04b4 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -24,6 +24,20 @@ from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.utils import LoRABatchInfo +from sglang.srt.utils import is_cuda, is_hip, is_cpu, cpu_has_amx_support + +# Import activation functions for LoRA (following Triton runner pattern) +_is_cuda = is_cuda() +_is_hip = is_hip() +_is_cpu = is_cpu() +_is_cpu_amx_available = cpu_has_amx_support() + +if _is_cuda: + from sgl_kernel import gelu_and_mul, silu_and_mul +elif _is_cpu and _is_cpu_amx_available: + pass +elif _is_hip: + from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul class BaseLayerWithLoRA(nn.Module): @@ -670,22 +684,24 @@ def _compute_lora_delta( lora_indices=lora_indices, ) - # Get intermediate dimension from LoRA B weights (gate_up output dim) - # gate_up_lora_b_weights shape: [num_loras, num_experts, intermediate_dim, max_rank] - _, _, intermediate_size, _ = self.gate_up_lora_b_weights.shape - - # Allocate intermediate cache for gate_up output (similar to intermediate_cache1 in base MoE) - # This stores the LoRA delta in intermediate space before down projection - lora_intermediate_cache = torch.zeros( - (num_tokens, intermediate_size), - dtype=hidden_states.dtype, # Use consistent dtype with model + # Get dimensions from LoRA weights + # gate_up_lora_b_weights shape: [num_loras, 
num_experts, gate_up_dim, max_rank] + # where gate_up_dim = 2 * intermediate_dim (gate + up combined) + _, _, gate_up_dim, _ = self.gate_up_lora_b_weights.shape + intermediate_dim = gate_up_dim // 2 # After activation, dimension halves + + # Get number of dispatched (token, expert) pairs + num_dispatched = token_ids.shape[0] + + # Keep expert outputs separate until final reduction + # Stage 1: gate_up_proj LoRA - keep experts separate + # Shape: (num_dispatched, gate_up_dim) where each row is one (token, expert) pair + lora_intermediate_cache1 = torch.zeros( + (num_dispatched, gate_up_dim), + dtype=hidden_states.dtype, device=hidden_states.device, ) - # Apply gate_up_proj LoRA: hidden_states -> intermediate space - # Store result in intermediate cache (no base_output means allocate new tensor) - # Note: topk_weights are NOT applied here - they are applied on the final down_proj output - # TODO (Jonahcb): remove return values after done debugging _, _ = per_expert_lora_forward( hidden_states=hidden_states, lora_a_weights=self.gate_up_lora_a_weights, @@ -696,19 +712,57 @@ def _compute_lora_delta( lora_ranks=lora_ranks, lora_scalings=scalings, num_experts=num_experts, - base_output=lora_intermediate_cache, # Store in our intermediate cache + base_output=lora_intermediate_cache1, is_down_proj=False, - topk_weights=None, # No router weight multiplication for gate_up + topk_weights=None, + keep_experts_separate=True, # Keep each (token, expert) pair separate + ) + + # Stage 2: Apply activation to each (token, expert) pair separately + # Output shape: (num_dispatched, intermediate_dim) - dimension halves due to SiLU/GeGLU + lora_intermediate_cache2 = torch.zeros( + (num_dispatched, intermediate_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, ) - # Apply down_proj LoRA: intermediate space -> hidden space, added to base_output - # Router weights (topk_weights) are applied here to scale each expert's contribution + activation = 
self.base_layer.moe_runner_config.activation + + if activation == "silu": + if _is_cuda: + silu_and_mul(lora_intermediate_cache1, lora_intermediate_cache2) + elif _is_hip: + vllm_ops.silu_and_mul( + lora_intermediate_cache2, lora_intermediate_cache1 + ) + else: + raise ValueError(f"Unsupported activation: {activation=}") + elif activation == "gelu": + if _is_cuda: + gelu_and_mul(lora_intermediate_cache1, lora_intermediate_cache2) + elif _is_hip: + vllm_ops.gelu_and_mul( + lora_intermediate_cache2, lora_intermediate_cache1 + ) + else: + raise ValueError(f"Unsupported activation: {activation=}") + else: + raise ValueError(f"Unsupported activation: {activation=}") + + # Stage 3: down_proj LoRA - keep experts separate + # Shape: (num_dispatched, hidden_size) if ( self.down_lora_a_weights is not None and self.down_lora_b_weights is not None ): + lora_intermediate_cache3 = torch.zeros( + (num_dispatched, hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + _, _ = per_expert_lora_forward( - hidden_states=lora_intermediate_cache, # Use intermediate cache as input + hidden_states=lora_intermediate_cache2, lora_a_weights=self.down_lora_a_weights, lora_b_weights=self.down_lora_b_weights, token_ids=token_ids, @@ -717,9 +771,28 @@ def _compute_lora_delta( lora_ranks=lora_ranks, lora_scalings=scalings, num_experts=num_experts, - base_output=base_output, # Add directly to base_output in-place + base_output=lora_intermediate_cache3, is_down_proj=True, - topk_weights=sorted_topk_weights, # Apply router weights to final output + topk_weights=None, # Don't apply weights in kernel + keep_experts_separate=True, # Keep each (token, expert) pair separate + ) + + # Stage 4: Final reduction - combine expert outputs with router weights + # Similar to moe_sum_reduce in base Triton MoE + # For each token, sum: output[t] += Σ_k (cache3[d] * topk_weights[d]) + # where d iterates over all dispatched pairs for token t + + # Apply router weights to each (token, expert) 
output + weighted_outputs = lora_intermediate_cache3 * sorted_topk_weights.unsqueeze( + -1 + ) + + # Scatter-add to combine experts per token + # token_ids[d] tells us which token row to add weighted_outputs[d] to + base_output.scatter_add_( + 0, + token_ids.unsqueeze(-1).expand(-1, hidden_size), + weighted_outputs, ) def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index 8dccb460949b..667b57fdf944 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -58,6 +58,8 @@ def _per_expert_lora_kernel( MUL_ROUTED_WEIGHT: tl.constexpr, # Whether this is down_proj (affects stacking factor for rank calculation) IS_DOWN_PROJ: tl.constexpr, + # Whether to keep expert outputs separate + KEEP_EXPERTS_SEPARATE: tl.constexpr, ): """ Compute per-expert LoRA delta: @@ -206,19 +208,27 @@ def _per_expert_lora_kernel( # ---------------------------- # Accumulate into global output (base_output) and store to lora_output # ---------------------------- - out_row_base = actual_token_id * output_dim - out_ptrs = output_ptr + out_row_base + out_offs - lora_out_ptrs = lora_output_ptr + out_row_base + out_offs - # Convert to output dtype (matches hidden_states dtype, could be float16/bfloat16/float32) out_vals_typed = out_vals.to(output_ptr.dtype.element_ty) - # Add to base_output in-place - tl.atomic_add(out_ptrs, out_vals_typed, out_mask) - - # TODO (Jonahcb): remove unnecessary store to lora_output tensor after done debugging - # Also store to separate lora_output tensor (same dtype) - tl.atomic_add(lora_out_ptrs, out_vals_typed, out_mask) + if KEEP_EXPERTS_SEPARATE: + # Write to spatial_id position (keeps each (token, expert) pair separate) + out_row_base = spatial_id * output_dim + out_ptrs = output_ptr + out_row_base + out_offs + lora_out_ptrs = lora_output_ptr + 
out_row_base + out_offs + # Use regular store (not atomic) since each spatial_id is unique + tl.store(out_ptrs, out_vals_typed, out_mask) + tl.store(lora_out_ptrs, out_vals_typed, out_mask) + else: + # Write to actual_token_id position (combines experts per token via atomic_add) + out_row_base = actual_token_id * output_dim + out_ptrs = output_ptr + out_row_base + out_offs + lora_out_ptrs = lora_output_ptr + out_row_base + out_offs + # Add to base_output in-place + tl.atomic_add(out_ptrs, out_vals_typed, out_mask) + # TODO (Jonahcb): remove unnecessary store to lora_output tensor after done debugging + # Also store to separate lora_output tensor (same dtype) + tl.atomic_add(lora_out_ptrs, out_vals_typed, out_mask) def per_expert_lora_forward( @@ -234,6 +244,7 @@ def per_expert_lora_forward( base_output: torch.Tensor = None, is_down_proj: bool = False, topk_weights: torch.Tensor = None, + keep_experts_separate: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for per-expert LoRA computation using a 3D Triton grid: @@ -250,16 +261,20 @@ def per_expert_lora_forward( lora_ranks: [num_loras] - Rank for each LoRA lora_scalings: [num_loras] - Scaling factor for each LoRA num_experts: Total number of experts - base_output: [num_tokens, output_dim] - Base MoE output (modified in-place) + base_output: Output tensor - shape depends on keep_experts_separate: + [num_tokens, output_dim] if False (combined) + [num_dispatched, output_dim] if True (separate) is_down_proj: Whether this is for down_proj (intermediate_dim -> hidden_dim) or gate_up_proj (hidden_dim -> intermediate_dim) topk_weights: [num_dispatched] - Router weights for each dispatched token. - Always multiplied by output (router weights are applied to final output). + Only applied when is_down_proj=True and keep_experts_separate=False. + keep_experts_separate: If True, keeps each (token, expert) pair's output separate + (like base Triton MoE). If False, combines experts per token. 
Returns: tuple of: - output: [num_tokens, output_dim] - Base output + LoRA delta (in-place) - lora_output: [num_tokens, output_dim] - Just the LoRA delta contribution + output: Base output + LoRA delta + lora_output: Just the LoRA delta contribution """ # Shapes @@ -286,35 +301,39 @@ def per_expert_lora_forward( lora_ranks = lora_ranks.contiguous() lora_scalings = lora_scalings.contiguous() - # Handle topk_weights (always provided, but only applied for down_proj) - mul_routed_weight = is_down_proj # Apply router weights only to down_proj output + # Handle topk_weights (only applied for down_proj when combining experts) + # Don't apply router weights in kernel when keeping experts separate + mul_routed_weight = is_down_proj and not keep_experts_separate if topk_weights is not None: topk_weights = topk_weights.contiguous() else: # Create dummy tensor if not provided topk_weights = torch.empty(0, device=device, dtype=dtype) + # Determine output shape based on whether we keep experts separate + if keep_experts_separate: + output_shape = (num_dispatched, output_dim) + else: + output_shape = (num_tokens, output_dim) + # Initialize or reuse output tensor for in-place addition if base_output is None: # Use specified dtype for consistency with model output = torch.zeros( - num_tokens, - output_dim, + *output_shape, dtype=dtype, device=device, ) else: output = base_output - assert output.shape == ( - num_tokens, - output_dim, - ), f"Expected shape ({num_tokens}, {output_dim}), got {output.shape}" + assert ( + output.shape == output_shape + ), f"Expected shape {output_shape}, got {output.shape}" assert output.device == device # Allocate separate tensor for just the LoRA contribution lora_output = torch.zeros( - num_tokens, - output_dim, + *output_shape, dtype=dtype, device=device, ) @@ -365,6 +384,8 @@ def per_expert_lora_forward( MUL_ROUTED_WEIGHT=mul_routed_weight, # Whether this is down_proj IS_DOWN_PROJ=is_down_proj, + # Whether to keep expert outputs separate + 
KEEP_EXPERTS_SEPARATE=keep_experts_separate, ) return output, lora_output From bbae67dae95a899ab1a05d029f34b5fdc5cd0039 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 31 Dec 2025 16:55:18 -0800 Subject: [PATCH 053/150] remove unused parameters --- python/sglang/srt/lora/layers.py | 4 - .../lora/triton_ops/per_expert_lora_moe.py | 78 +++++-------------- 2 files changed, 19 insertions(+), 63 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index f01167ca04b4..483155e9edd1 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -714,8 +714,6 @@ def _compute_lora_delta( num_experts=num_experts, base_output=lora_intermediate_cache1, is_down_proj=False, - topk_weights=None, - keep_experts_separate=True, # Keep each (token, expert) pair separate ) # Stage 2: Apply activation to each (token, expert) pair separately @@ -773,8 +771,6 @@ def _compute_lora_delta( num_experts=num_experts, base_output=lora_intermediate_cache3, is_down_proj=True, - topk_weights=None, # Don't apply weights in kernel - keep_experts_separate=True, # Keep each (token, expert) pair separate ) # Stage 4: Final reduction - combine expert outputs with router weights diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py index 667b57fdf944..5537c3aa89ec 100644 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py @@ -31,7 +31,6 @@ def _per_expert_lora_kernel( token_ids_ptr, # [num_dispatched] -> index into hidden/output expert_ids_ptr, # [num_dispatched] lora_ids_ptr, # [num_dispatched] - topk_weights_ptr, # [num_dispatched] - Router weights for each dispatched token # Dimensions input_dim: tl.constexpr, output_dim: tl.constexpr, @@ -54,12 +53,8 @@ def _per_expert_lora_kernel( lora_scalings_ptr, # Block size (used for input and output tiling; rank is not tiled) BLOCK_SIZE: 
tl.constexpr, - # Whether to multiply by router weights - MUL_ROUTED_WEIGHT: tl.constexpr, # Whether this is down_proj (affects stacking factor for rank calculation) IS_DOWN_PROJ: tl.constexpr, - # Whether to keep expert outputs separate - KEEP_EXPERTS_SEPARATE: tl.constexpr, ): """ Compute per-expert LoRA delta: @@ -198,37 +193,21 @@ def _per_expert_lora_kernel( # Apply scaling out_vals *= scaling - # Apply router weight if enabled (matches base MoE behavior) - if MUL_ROUTED_WEIGHT: - topk_weight = tl.load( - topk_weights_ptr + spatial_id - ) # I don't think this correctly understands how top_k weights is organized (now moe dispatch reorganizes topk weights to work w this) - out_vals *= topk_weight + # Router weights are applied in the final reduction step, not in the kernel # ---------------------------- - # Accumulate into global output (base_output) and store to lora_output + # Store results for each (token, expert) pair separately # ---------------------------- # Convert to output dtype (matches hidden_states dtype, could be float16/bfloat16/float32) out_vals_typed = out_vals.to(output_ptr.dtype.element_ty) - if KEEP_EXPERTS_SEPARATE: - # Write to spatial_id position (keeps each (token, expert) pair separate) - out_row_base = spatial_id * output_dim - out_ptrs = output_ptr + out_row_base + out_offs - lora_out_ptrs = lora_output_ptr + out_row_base + out_offs - # Use regular store (not atomic) since each spatial_id is unique - tl.store(out_ptrs, out_vals_typed, out_mask) - tl.store(lora_out_ptrs, out_vals_typed, out_mask) - else: - # Write to actual_token_id position (combines experts per token via atomic_add) - out_row_base = actual_token_id * output_dim - out_ptrs = output_ptr + out_row_base + out_offs - lora_out_ptrs = lora_output_ptr + out_row_base + out_offs - # Add to base_output in-place - tl.atomic_add(out_ptrs, out_vals_typed, out_mask) - # TODO (Jonahcb): remove unnecessary store to lora_output tensor after done debugging - # Also store to separate 
lora_output tensor (same dtype) - tl.atomic_add(lora_out_ptrs, out_vals_typed, out_mask) + # Write to spatial_id position (each (token, expert) pair gets its own row) + out_row_base = spatial_id * output_dim + out_ptrs = output_ptr + out_row_base + out_offs + lora_out_ptrs = lora_output_ptr + out_row_base + out_offs + # Use regular store since each spatial_id is unique + tl.store(out_ptrs, out_vals_typed, out_mask) + tl.store(lora_out_ptrs, out_vals_typed, out_mask) def per_expert_lora_forward( @@ -243,13 +222,14 @@ def per_expert_lora_forward( num_experts: int, base_output: torch.Tensor = None, is_down_proj: bool = False, - topk_weights: torch.Tensor = None, - keep_experts_separate: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for per-expert LoRA computation using a 3D Triton grid: grid = (spatial, slices, loras) + Mathematically correct implementation that keeps expert outputs separate + until final reduction, matching the base Triton MoE pattern. + Args: hidden_states: [num_tokens, input_dim] where input_dim is hidden_dim for gate_up_proj or intermediate_dim for down_proj @@ -261,20 +241,15 @@ def per_expert_lora_forward( lora_ranks: [num_loras] - Rank for each LoRA lora_scalings: [num_loras] - Scaling factor for each LoRA num_experts: Total number of experts - base_output: Output tensor - shape depends on keep_experts_separate: - [num_tokens, output_dim] if False (combined) - [num_dispatched, output_dim] if True (separate) + base_output: Output tensor with shape [num_dispatched, output_dim] + Each row contains the output for one (token, expert) pair is_down_proj: Whether this is for down_proj (intermediate_dim -> hidden_dim) or gate_up_proj (hidden_dim -> intermediate_dim) - topk_weights: [num_dispatched] - Router weights for each dispatched token. - Only applied when is_down_proj=True and keep_experts_separate=False. - keep_experts_separate: If True, keeps each (token, expert) pair's output separate - (like base Triton MoE). 
If False, combines experts per token. Returns: tuple of: - output: Base output + LoRA delta - lora_output: Just the LoRA delta contribution + output: LoRA delta for each (token, expert) pair + lora_output: Just the LoRA delta contribution (same as output) """ # Shapes @@ -301,20 +276,10 @@ def per_expert_lora_forward( lora_ranks = lora_ranks.contiguous() lora_scalings = lora_scalings.contiguous() - # Handle topk_weights (only applied for down_proj when combining experts) - # Don't apply router weights in kernel when keeping experts separate - mul_routed_weight = is_down_proj and not keep_experts_separate - if topk_weights is not None: - topk_weights = topk_weights.contiguous() - else: - # Create dummy tensor if not provided - topk_weights = torch.empty(0, device=device, dtype=dtype) + # Router weights are always applied in the final reduction step, never in the kernel - # Determine output shape based on whether we keep experts separate - if keep_experts_separate: - output_shape = (num_dispatched, output_dim) - else: - output_shape = (num_tokens, output_dim) + # Always keep experts separate until final reduction + output_shape = (num_dispatched, output_dim) # Initialize or reuse output tensor for in-place addition if base_output is None: @@ -358,7 +323,6 @@ def per_expert_lora_forward( token_ids, # token_ids_ptr expert_ids, # expert_ids_ptr lora_ids, # lora_ids_ptr - topk_weights, # topk_weights_ptr # Dimensions input_dim, # input_dim (hidden_dim for gate_up_proj, intermediate_dim for down_proj) output_dim, # output_dim (intermediate_dim for gate_up_proj, hidden_dim for down_proj) @@ -380,12 +344,8 @@ def per_expert_lora_forward( lora_scalings, # lora_scalings_ptr # Block size (constexpr) BLOCK_SIZE=BLOCK_SIZE, - # Router weight multiplication flag - MUL_ROUTED_WEIGHT=mul_routed_weight, # Whether this is down_proj IS_DOWN_PROJ=is_down_proj, - # Whether to keep expert outputs separate - KEEP_EXPERTS_SEPARATE=keep_experts_separate, ) return output, lora_output From 
3abf25e689e80bbc6ad5c52fe7b8f5b2c8fc5dcd Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 31 Dec 2025 17:11:46 -0800 Subject: [PATCH 054/150] fix mismatch types --- python/sglang/srt/lora/layers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 483155e9edd1..bb415a5d73f8 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -783,6 +783,9 @@ def _compute_lora_delta( -1 ) + # Ensure weighted_outputs has the same dtype as base_output for scatter_add_ + weighted_outputs = weighted_outputs.to(base_output.dtype) + # Scatter-add to combine experts per token # token_ids[d] tells us which token row to add weighted_outputs[d] to base_output.scatter_add_( From f8e99e52fabbb9603cb957b75e06ce0c5eadd89e Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Thu, 1 Jan 2026 09:05:38 -0800 Subject: [PATCH 055/150] remove unnecessary if --- python/sglang/srt/lora/layers.py | 75 +++++++++++++++----------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index bb415a5d73f8..9bfe47cbe067 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -749,50 +749,47 @@ def _compute_lora_delta( # Stage 3: down_proj LoRA - keep experts separate # Shape: (num_dispatched, hidden_size) - if ( - self.down_lora_a_weights is not None - and self.down_lora_b_weights is not None - ): - lora_intermediate_cache3 = torch.zeros( - (num_dispatched, hidden_size), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - _, _ = per_expert_lora_forward( - hidden_states=lora_intermediate_cache2, - lora_a_weights=self.down_lora_a_weights, - lora_b_weights=self.down_lora_b_weights, - token_ids=token_ids, - expert_ids=expert_ids, - lora_ids=lora_ids, - lora_ranks=lora_ranks, - lora_scalings=scalings, - num_experts=num_experts, - base_output=lora_intermediate_cache3, - 
is_down_proj=True, - ) + lora_intermediate_cache3 = torch.zeros( + (num_dispatched, hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) - # Stage 4: Final reduction - combine expert outputs with router weights - # Similar to moe_sum_reduce in base Triton MoE - # For each token, sum: output[t] += Σ_k (cache3[d] * topk_weights[d]) - # where d iterates over all dispatched pairs for token t + _, _ = per_expert_lora_forward( + hidden_states=lora_intermediate_cache2, + lora_a_weights=self.down_lora_a_weights, + lora_b_weights=self.down_lora_b_weights, + token_ids=token_ids, + expert_ids=expert_ids, + lora_ids=lora_ids, + lora_ranks=lora_ranks, + lora_scalings=scalings, + num_experts=num_experts, + base_output=lora_intermediate_cache3, + is_down_proj=True, + ) - # Apply router weights to each (token, expert) output - weighted_outputs = lora_intermediate_cache3 * sorted_topk_weights.unsqueeze( - -1 - ) + # Stage 4: Final reduction - combine expert outputs with router weights + # Similar to moe_sum_reduce in base Triton MoE + # For each token, sum: output[t] += Σ_k (cache3[d] * topk_weights[d]) + # where d iterates over all dispatched pairs for token t - # Ensure weighted_outputs has the same dtype as base_output for scatter_add_ - weighted_outputs = weighted_outputs.to(base_output.dtype) + # Apply router weights to each (token, expert) output + weighted_outputs = lora_intermediate_cache3 * sorted_topk_weights.unsqueeze( + -1 + ) - # Scatter-add to combine experts per token - # token_ids[d] tells us which token row to add weighted_outputs[d] to - base_output.scatter_add_( - 0, - token_ids.unsqueeze(-1).expand(-1, hidden_size), - weighted_outputs, - ) + # Ensure weighted_outputs has the same dtype as base_output for scatter_add_ + weighted_outputs = weighted_outputs.to(base_output.dtype) + + # Scatter-add to combine experts per token + # token_ids[d] tells us which token row to add weighted_outputs[d] to + base_output.scatter_add_( + 0, + 
token_ids.unsqueeze(-1).expand(-1, hidden_size), + weighted_outputs, + ) def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For MoE layers, tensor parallelism is typically not used From 315c64d6a2b00dcac0b427c5995526508601ead8 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 9 Jan 2026 17:54:44 -0800 Subject: [PATCH 056/150] refactor so that LoRA computations are added inside base MoE path --- python/sglang/srt/lora/layers.py | 236 ++++---- python/sglang/srt/lora/lora_moe_runners.py | 620 +++++++++++++++++++++ 2 files changed, 718 insertions(+), 138 deletions(-) create mode 100644 python/sglang/srt/lora/lora_moe_runners.py diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 9bfe47cbe067..f26b240f9464 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -24,20 +24,6 @@ from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.utils import LoRABatchInfo -from sglang.srt.utils import is_cuda, is_hip, is_cpu, cpu_has_amx_support - -# Import activation functions for LoRA (following Triton runner pattern) -_is_cuda = is_cuda() -_is_hip = is_hip() -_is_cpu = is_cpu() -_is_cpu_amx_available = cpu_has_amx_support() - -if _is_cuda: - from sgl_kernel import gelu_and_mul, silu_and_mul -elif _is_cpu and _is_cpu_amx_available: - pass -elif _is_hip: - from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul class BaseLayerWithLoRA(nn.Module): @@ -594,10 +580,14 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): class FusedMoEWithLoRA(BaseLayerWithLoRA): """ - Wrapper around FusedMoE that adds parallel LoRA computation. + Wrapper around FusedMoE that integrates LoRA into the MoE computation. + + Design: LoRA deltas are added at specific points in the MoE forward pass: + 1. After gate_up projection, BEFORE activation (halfway through) + 2. 
After down projection, BEFORE final reduction - Design: Base MoE and LoRA Delta run independently and merge at the end. - This preserves SGLang's existing 3-stage MoE architecture unchanged. + This follows the vLLM/HF approach where LoRA is fused into the computation + rather than computed independently and added at the end. """ def __init__( @@ -611,6 +601,7 @@ def __init__( self.gate_up_lora_b_weights = None self.down_lora_a_weights = None self.down_lora_b_weights = None + self._lora_runner = None def set_lora_info( self, @@ -626,41 +617,20 @@ def set_lora_info( self.down_lora_a_weights = down_lora_a_weights self.down_lora_b_weights = down_lora_b_weights - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): - """ - Forward pass with parallel LoRA computation. - - Flow: - 1. Base MoE forward - 2. Parallel LoRA delta computation (if enabled, added in-place) - 3. Return modified base_output - """ - hidden_states_for_lora = hidden_states.clone() - # Run base MoE - base_output = self.base_layer.forward(hidden_states, topk_output, **kwargs) - - # If LoRA is enabled, compute delta and add in-place for memory efficiency - if self.set_lora and self.gate_up_lora_a_weights is not None: - self._compute_lora_delta(hidden_states_for_lora, topk_output, base_output) - - return base_output - - def _compute_lora_delta( + def _get_lora_info( self, - hidden_states: torch.Tensor, topk_output: TopKOutput, - base_output: torch.Tensor, - ) -> None: + ): """ - Compute LoRA delta using per-expert LoRA weights and add to base_output in-place. + Build LoRAInfo for the current batch. - Dispatch tokens to experts and compute per-expert deltas. - Uses intermediate caches similar to base MoE implementation for memory efficiency. + Returns None if LoRA is not enabled or weights are not set. 
""" + if not self.set_lora or self.gate_up_lora_a_weights is None: + return None + + from sglang.srt.lora.lora_moe_runners import LoRAInfo from sglang.srt.lora.moe_dispatch import moe_dispatch - from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( - per_expert_lora_forward, - ) # Get dispatch info from TopKOutput topk_ids = topk_output.topk_ids # [num_tokens, top_k] @@ -674,123 +644,113 @@ def _compute_lora_delta( # Use precomputed per-token LoRA indices from forward batch lora_indices = self.lora_backend.forward_batch.token_lora_indices - num_experts = self.base_layer.num_experts - num_tokens, hidden_size = hidden_states.shape - # Dispatch tokens to experts - token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch( + token_ids, expert_ids, _, lora_ids = moe_dispatch( topk_ids=topk_ids, topk_weights=topk_weights, lora_indices=lora_indices, ) - # Get dimensions from LoRA weights - # gate_up_lora_b_weights shape: [num_loras, num_experts, gate_up_dim, max_rank] - # where gate_up_dim = 2 * intermediate_dim (gate + up combined) - _, _, gate_up_dim, _ = self.gate_up_lora_b_weights.shape - intermediate_dim = gate_up_dim // 2 # After activation, dimension halves - - # Get number of dispatched (token, expert) pairs - num_dispatched = token_ids.shape[0] - - # Keep expert outputs separate until final reduction - # Stage 1: gate_up_proj LoRA - keep experts separate - # Shape: (num_dispatched, gate_up_dim) where each row is one (token, expert) pair - lora_intermediate_cache1 = torch.zeros( - (num_dispatched, gate_up_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - _, _ = per_expert_lora_forward( - hidden_states=hidden_states, - lora_a_weights=self.gate_up_lora_a_weights, - lora_b_weights=self.gate_up_lora_b_weights, + return LoRAInfo( + gate_up_lora_a_weights=self.gate_up_lora_a_weights, + gate_up_lora_b_weights=self.gate_up_lora_b_weights, + down_lora_a_weights=self.down_lora_a_weights, + 
down_lora_b_weights=self.down_lora_b_weights, token_ids=token_ids, expert_ids=expert_ids, lora_ids=lora_ids, lora_ranks=lora_ranks, lora_scalings=scalings, - num_experts=num_experts, - base_output=lora_intermediate_cache1, - is_down_proj=False, + num_experts=self.base_layer.num_experts, ) - # Stage 2: Apply activation to each (token, expert) pair separately - # Output shape: (num_dispatched, intermediate_dim) - dimension halves due to SiLU/GeGLU - lora_intermediate_cache2 = torch.zeros( - (num_dispatched, intermediate_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): + """ + Forward pass with integrated LoRA computation. - activation = self.base_layer.moe_runner_config.activation - - if activation == "silu": - if _is_cuda: - silu_and_mul(lora_intermediate_cache1, lora_intermediate_cache2) - elif _is_hip: - vllm_ops.silu_and_mul( - lora_intermediate_cache2, lora_intermediate_cache1 - ) - else: - raise ValueError(f"Unsupported activation: {activation=}") - elif activation == "gelu": - if _is_cuda: - gelu_and_mul(lora_intermediate_cache1, lora_intermediate_cache2) - elif _is_hip: - vllm_ops.gelu_and_mul( - lora_intermediate_cache2, lora_intermediate_cache1 - ) - else: - raise ValueError(f"Unsupported activation: {activation=}") - else: - raise ValueError(f"Unsupported activation: {activation=}") + LoRA deltas are added at the correct points inside the MoE computation: + 1. After gate_up projection, before activation + 2. 
After down projection, before final reduction + """ + # If LoRA is not enabled, just run base MoE + if not self.set_lora or self.gate_up_lora_a_weights is None: + return self.base_layer.forward(hidden_states, topk_output, **kwargs) - # Stage 3: down_proj LoRA - keep experts separate - # Shape: (num_dispatched, hidden_size) + # Build LoRA info for this batch + lora_info = self._get_lora_info(topk_output) - lora_intermediate_cache3 = torch.zeros( - (num_dispatched, hidden_size), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + # For now, we use the integrated runner approach only for Triton backend + # This wraps the base layer's forward with LoRA integration + return self._forward_with_lora(hidden_states, topk_output, lora_info, **kwargs) - _, _ = per_expert_lora_forward( - hidden_states=lora_intermediate_cache2, - lora_a_weights=self.down_lora_a_weights, - lora_b_weights=self.down_lora_b_weights, - token_ids=token_ids, - expert_ids=expert_ids, - lora_ids=lora_ids, - lora_ranks=lora_ranks, - lora_scalings=scalings, - num_experts=num_experts, - base_output=lora_intermediate_cache3, - is_down_proj=True, + def _forward_with_lora( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + lora_info, + **kwargs, + ): + """ + Run MoE forward with LoRA integration at the correct points. + + This method hooks into the base layer's computation to add LoRA deltas + at the right stages. 
+ """ + from sglang.srt.lora.lora_moe_runners import TritonRunnerCoreWithLoRA + from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo + + # Get the base layer's dispatch and combine logic + base_layer = self.base_layer + origin_hidden_states_dim = hidden_states.shape[-1] + + # Dispatch tokens + dispatch_output = base_layer.dispatcher.dispatch( + hidden_states=hidden_states, topk_output=topk_output ) - # Stage 4: Final reduction - combine expert outputs with router weights - # Similar to moe_sum_reduce in base Triton MoE - # For each token, sum: output[t] += Σ_k (cache3[d] * topk_weights[d]) - # where d iterates over all dispatched pairs for token t + # Create LoRA-aware runner if not already created + if self._lora_runner is None: + self._lora_runner = TritonRunnerCoreWithLoRA(base_layer.moe_runner_config) + + # Build quant info (for unquantized, this is straightforward) + quant_info = TritonMoeQuantInfo( + w13_weight=base_layer.w13_weight, + w2_weight=base_layer.w2_weight, + b13=getattr(base_layer, "w13_weight_bias", None), + b2=getattr(base_layer, "w2_weight_bias", None), + ) - # Apply router weights to each (token, expert) output - weighted_outputs = lora_intermediate_cache3 * sorted_topk_weights.unsqueeze( - -1 + # Get running state (includes config from pre-permute) + from sglang.srt.layers.moe.moe_runner.triton import ( + pre_permute_standard_to_triton, ) + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput - # Ensure weighted_outputs has the same dtype as base_output for scatter_add_ - weighted_outputs = weighted_outputs.to(base_output.dtype) + running_state = {} + runner_input = pre_permute_standard_to_triton( + dispatch_output, quant_info, base_layer.moe_runner_config, running_state + ) - # Scatter-add to combine experts per token - # token_ids[d] tells us which token row to add weighted_outputs[d] to - base_output.scatter_add_( - 0, - token_ids.unsqueeze(-1).expand(-1, hidden_size), - weighted_outputs, + # 
Run with LoRA integration + runner_output = self._lora_runner.run( + runner_input, quant_info, running_state, lora_info ) + # Combine and return + combine_input = StandardCombineInput(hidden_states=runner_output.hidden_states) + + final_hidden_states = base_layer.dispatcher.combine(combine_input=combine_input) + final_hidden_states = final_hidden_states[ + ..., :origin_hidden_states_dim + ].contiguous() + + if base_layer.reduce_results and ( + base_layer.moe_tp_size > 1 or base_layer.moe_ep_size > 1 + ): + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states + def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): # For MoE layers, tensor parallelism is typically not used # Return weights unchanged diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py new file mode 100644 index 000000000000..d81034d9d88a --- /dev/null +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -0,0 +1,620 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LoRA-aware MoE runners that integrate LoRA deltas into the MoE computation. + +The key insight is that LoRA deltas must be added at specific points: +1. After gate_up projection, BEFORE activation (halfway through) +2. 
After down projection, BEFORE final reduction (at the end) + +This differs from computing LoRA independently and adding at the very end. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch +import triton.language as tl + +from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import ( + TritonMoeQuantInfo, + TritonRunnerCore, + TritonRunnerInput, + TritonRunnerOutput, +) +from sglang.srt.utils import is_cuda, is_hip + +_is_hip = is_hip() +_is_cuda = is_cuda() + +if _is_cuda or _is_hip: + from sgl_kernel import gelu_and_mul, silu_and_mul + + +@dataclass +class LoRAInfo: + """LoRA weights and dispatch info for MoE computation.""" + + # LoRA weights: [num_loras, num_experts, dim1, dim2] + gate_up_lora_a_weights: ( + torch.Tensor + ) # [num_loras, num_experts, max_rank, hidden_dim] + gate_up_lora_b_weights: ( + torch.Tensor + ) # [num_loras, num_experts, gate_up_dim, max_rank] + down_lora_a_weights: ( + torch.Tensor + ) # [num_loras, num_experts, max_rank, intermediate_dim] + down_lora_b_weights: torch.Tensor # [num_loras, num_experts, hidden_dim, max_rank] + + # Dispatch info (sorted by expert) + token_ids: torch.Tensor # [num_dispatched] - original token indices + expert_ids: torch.Tensor # [num_dispatched] - expert IDs + lora_ids: torch.Tensor # [num_dispatched] - LoRA adapter IDs + + # LoRA config per adapter + lora_ranks: torch.Tensor # [num_loras] + lora_scalings: torch.Tensor # [num_loras] + + num_experts: int + + +class TritonRunnerCoreWithLoRA(TritonRunnerCore): + """ + LoRA-aware wrapper around TritonRunnerCore. + + Integrates LoRA deltas at the correct points in the MoE forward pass: + 1. Base gate_up projection + LoRA gate_up delta -> activation + 2. Base down projection + LoRA down delta -> final reduction + + This follows the vLLM/HF approach where LoRA is fused into the computation + rather than computed independently. 
+ """ + + def __init__(self, config: MoeRunnerConfig): + super().__init__(config) + + def run( + self, + runner_input: TritonRunnerInput, + quant_info: TritonMoeQuantInfo, + running_state: dict, + lora_info: Optional[LoRAInfo] = None, + ) -> TritonRunnerOutput: + """ + Run MoE with integrated LoRA computation. + + This method extends TritonRunnerCore.run() by inserting LoRA delta + computations at the correct points in the MoE forward pass. + + Args: + runner_input: Standard Triton runner input + quant_info: Quantization info for base weights + running_state: Running state dict + lora_info: Optional LoRA weights and dispatch info + + Returns: + TritonRunnerOutput with combined base + LoRA output + """ + # If no LoRA, use base implementation + if lora_info is None: + return super().run(runner_input, quant_info, running_state) + + # Extract common variables + hidden_states = runner_input.hidden_states + topk_weights = runner_input.topk_weights + topk_ids = runner_input.topk_ids + sorted_token_ids = runner_input.sorted_token_ids + expert_ids = runner_input.expert_ids + num_tokens_post_padded = runner_input.num_tokens_post_padded + + w13 = quant_info.w13_weight + w2 = quant_info.w2_weight + b13 = quant_info.b13 + b2 = quant_info.b2 + a13_scale = quant_info.a13_scale + a2_scale = quant_info.a2_scale + w13_scale = quant_info.w13_scale + w2_scale = quant_info.w2_scale + w13_zp = quant_info.w13_zp + w2_zp = quant_info.w2_zp + block_shape = quant_info.block_shape + per_channel_quant = quant_info.per_channel_quant + use_fp8_w8a8 = quant_info.use_fp8_w8a8 + use_int8_w8a8 = quant_info.use_int8_w8a8 + use_int8_w8a16 = quant_info.use_int8_w8a16 + use_int4_w4a16 = quant_info.use_int4_w4a16 + + activation = self.config.activation + no_combine = self.config.no_combine + inplace = self.config.inplace + gemm1_alpha = self.config.gemm1_alpha + gemm1_limit = self.config.gemm1_clamp_limit + routed_scaling_factor = self.config.routed_scaling_factor + apply_router_weight_on_input = 
self.config.apply_router_weight_on_input + + assert self.config.is_gated, "Only gated MoEs are supported for Triton runner" + + M = hidden_states.shape[0] + E, N, _ = w13.shape + compute_type = ( + tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + ) + + # Import functions needed for MoE computation + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + invoke_fused_moe_kernel, + moe_sum_reduce_torch_compile, + moe_sum_reduce_triton, + swiglu_with_alpha_and_limit, + ) + + hidden_states = runner_input.hidden_states + topk_weights = runner_input.topk_weights + topk_ids = runner_input.topk_ids + sorted_token_ids = runner_input.sorted_token_ids + expert_ids = runner_input.expert_ids + num_tokens_post_padded = runner_input.num_tokens_post_padded + + w13 = quant_info.w13_weight + w2 = quant_info.w2_weight + b13 = quant_info.b13 + b2 = quant_info.b2 + a13_scale = quant_info.a13_scale + a2_scale = quant_info.a2_scale + w13_scale = quant_info.w13_scale + w2_scale = quant_info.w2_scale + w13_zp = quant_info.w13_zp + w2_zp = quant_info.w2_zp + block_shape = quant_info.block_shape + per_channel_quant = quant_info.per_channel_quant + use_fp8_w8a8 = quant_info.use_fp8_w8a8 + use_int8_w8a8 = quant_info.use_int8_w8a8 + use_int8_w8a16 = quant_info.use_int8_w8a16 + use_int4_w4a16 = quant_info.use_int4_w4a16 + + activation = self.config.activation + no_combine = self.config.no_combine + inplace = self.config.inplace + gemm1_alpha = self.config.gemm1_alpha + gemm1_limit = self.config.gemm1_clamp_limit + routed_scaling_factor = self.config.routed_scaling_factor + apply_router_weight_on_input = self.config.apply_router_weight_on_input + + assert self.config.is_gated, "Only gated MoEs are supported for Triton runner" + + M = hidden_states.shape[0] + E, N, _ = w13.shape + compute_type = ( + tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + ) + + # ============================================================ + # Stage 1: Gate/Up 
projection (base) + # ============================================================ + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + invoke_fused_moe_kernel( + hidden_states, + w13, + b13, + intermediate_cache1, + a13_scale, + w13_scale, + w13_zp, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + topk_ids.shape[1], + running_state["config"], + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ) + + # ============================================================ + # Stage 1.5: Add LoRA gate_up delta BEFORE activation + # ============================================================ + self._add_lora_gate_up_delta( + hidden_states=hidden_states, + intermediate_cache=intermediate_cache1, + topk_ids=topk_ids, + lora_info=lora_info, + ) + + # ============================================================ + # Stage 2: Activation (SiLU or GELU) + # ============================================================ + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if activation == "silu": + if gemm1_alpha is not None: + assert gemm1_limit is not None + intermediate_cache2 = swiglu_with_alpha_and_limit( + intermediate_cache1.view(-1, N), + gemm1_alpha, + gemm1_limit, + ) + elif _is_cuda or _is_hip: + silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + raise ValueError(f"Unsupported platform for activation: {activation=}") + elif activation == "gelu": + assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" + assert gemm1_limit is None, "gemm1_limit is not supported for gelu" + if _is_cuda or _is_hip: + 
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + raise ValueError(f"Unsupported platform for activation: {activation=}") + else: + raise ValueError(f"Unsupported activation: {activation=}") + + # ============================================================ + # Stage 3: Down projection (base) + # ============================================================ + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if no_combine: + assert not inplace + out_hidden_states = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + elif inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + b2, + ( + intermediate_cache3 + if not no_combine and topk_ids.shape[1] != 1 + else out_hidden_states.unsqueeze(0) + ), + a2_scale, + w2_scale, + w2_zp, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + running_state["config"], + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ) + + # ============================================================ + # Stage 3.5: Add LoRA down delta BEFORE final reduction + # ============================================================ + self._add_lora_down_delta( + intermediate_input=intermediate_cache2, + intermediate_cache=intermediate_cache3, + topk_ids=topk_ids, + topk_weights=topk_weights, + apply_router_weight_on_input=apply_router_weight_on_input, + lora_info=lora_info, + ) + + # ============================================================ + # Stage 4: Final reduction (sum across top_k) + # 
============================================================ + if routed_scaling_factor is None: + routed_scaling_factor = 1.0 + + if no_combine: + pass + elif _is_cuda: + if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0: + pass # we write directly into out_hidden_states + elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0: + torch.add( + intermediate_cache3[:, 0], + intermediate_cache3[:, 1], + out=out_hidden_states, + ).squeeze(dim=1) + else: + if M <= 32: + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + routed_scaling_factor, + ) + else: + moe_sum_reduce_triton( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + routed_scaling_factor, + ) + elif _is_hip: + from vllm import _custom_ops as vllm_ops + + vllm_ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + ) + else: + from vllm import _custom_ops as vllm_ops + + vllm_ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states, + ) + + return TritonRunnerOutput( + hidden_states=out_hidden_states, + ) + + def _add_lora_gate_up_delta( + self, + hidden_states: torch.Tensor, # [M, hidden_dim] + intermediate_cache: torch.Tensor, # [M, top_k, gate_up_dim] + topk_ids: torch.Tensor, # [M, top_k] + lora_info: LoRAInfo, + ) -> None: + """ + Add LoRA gate_up delta to intermediate_cache in-place. + + For each (token, expert) pair, computes: + delta = scaling * B @ (A @ hidden_states[token]) + and adds it to intermediate_cache[token, k] where k is the top_k index. 
+ """ + from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( + per_expert_lora_forward, + ) + + M, top_k, gate_up_dim = intermediate_cache.shape + num_dispatched = lora_info.token_ids.shape[0] + + # Compute LoRA delta for each (token, expert) pair + # Output shape: [num_dispatched, gate_up_dim] + lora_delta = torch.zeros( + (num_dispatched, gate_up_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + _, lora_delta = per_expert_lora_forward( + hidden_states=hidden_states, + lora_a_weights=lora_info.gate_up_lora_a_weights, + lora_b_weights=lora_info.gate_up_lora_b_weights, + token_ids=lora_info.token_ids, + expert_ids=lora_info.expert_ids, + lora_ids=lora_info.lora_ids, + lora_ranks=lora_info.lora_ranks, + lora_scalings=lora_info.lora_scalings, + num_experts=lora_info.num_experts, + base_output=lora_delta, + is_down_proj=False, + ) + + # Add delta to intermediate_cache at the right positions + # We need to map from dispatched indices back to (token, top_k_idx) pairs + self._scatter_add_to_topk_cache( + lora_delta=lora_delta, + intermediate_cache=intermediate_cache, + token_ids=lora_info.token_ids, + expert_ids=lora_info.expert_ids, + topk_ids=topk_ids, + ) + + def _add_lora_down_delta( + self, + intermediate_input: torch.Tensor, # [M * top_k, intermediate_dim] + intermediate_cache: torch.Tensor, # [M, top_k, hidden_dim] + topk_ids: torch.Tensor, # [M, top_k] + topk_weights: torch.Tensor, # [M, top_k] + apply_router_weight_on_input: bool, + lora_info: LoRAInfo, + ) -> None: + """ + Add LoRA down delta to intermediate_cache in-place. + + For each (token, expert) pair, computes: + delta = scaling * B @ (A @ intermediate_input[dispatched_idx]) + and adds it to intermediate_cache[token, k]. 
+ """ + from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( + per_expert_lora_forward, + ) + + M, top_k, hidden_dim = intermediate_cache.shape + + # Build indices to gather from intermediate_input + # For each dispatched (token, expert) pair, find which top_k slot it corresponds to + lora_intermediate_input = self._gather_dispatched_inputs( + intermediate_input=intermediate_input, + token_ids=lora_info.token_ids, + expert_ids=lora_info.expert_ids, + topk_ids=topk_ids, + M=M, + top_k=top_k, + ) + + # Compute LoRA delta for down projection + num_dispatched = lora_info.token_ids.shape[0] + lora_delta = torch.zeros( + (num_dispatched, hidden_dim), + dtype=intermediate_input.dtype, + device=intermediate_input.device, + ) + + # IMPORTANT: For down_proj, the input (lora_intermediate_input) is already + # gathered and indexed by dispatched position (0, 1, ..., num_dispatched-1), + # not by original token position. So we pass identity indices for token_ids + # to make the kernel read from the correct positions. 
+ dispatched_indices = torch.arange( + num_dispatched, + device=lora_info.token_ids.device, + dtype=lora_info.token_ids.dtype, + ) + + _, lora_delta = per_expert_lora_forward( + hidden_states=lora_intermediate_input, + lora_a_weights=lora_info.down_lora_a_weights, + lora_b_weights=lora_info.down_lora_b_weights, + token_ids=dispatched_indices, # Use identity indices, not original token_ids + expert_ids=lora_info.expert_ids, + lora_ids=lora_info.lora_ids, + lora_ranks=lora_info.lora_ranks, + lora_scalings=lora_info.lora_scalings, + num_experts=lora_info.num_experts, + base_output=lora_delta, + is_down_proj=True, + ) + + # Apply router weights if not already applied to input + # This matches the base MoE behavior + if not apply_router_weight_on_input: + # Get router weights for each dispatched pair + router_weights = self._gather_router_weights( + topk_weights=topk_weights, + token_ids=lora_info.token_ids, + expert_ids=lora_info.expert_ids, + topk_ids=topk_ids, + ) + lora_delta = lora_delta * router_weights.unsqueeze(-1) + + # Add delta to intermediate_cache + self._scatter_add_to_topk_cache( + lora_delta=lora_delta, + intermediate_cache=intermediate_cache, + token_ids=lora_info.token_ids, + expert_ids=lora_info.expert_ids, + topk_ids=topk_ids, + ) + + def _scatter_add_to_topk_cache( + self, + lora_delta: torch.Tensor, # [num_dispatched, dim] + intermediate_cache: torch.Tensor, # [M, top_k, dim] + token_ids: torch.Tensor, # [num_dispatched] + expert_ids: torch.Tensor, # [num_dispatched] + topk_ids: torch.Tensor, # [M, top_k] + ) -> None: + """ + Scatter-add lora_delta to intermediate_cache based on dispatch info. 
+ + For each dispatched index d: + - token_id = token_ids[d] + - expert_id = expert_ids[d] + - Find k such that topk_ids[token_id, k] == expert_id + - intermediate_cache[token_id, k] += lora_delta[d] + """ + M, top_k, dim = intermediate_cache.shape + + # Find the top_k index for each dispatched pair + # topk_ids[token_ids] gives [num_dispatched, top_k] + # We need to find which column matches expert_ids + expanded_topk = topk_ids[token_ids] # [num_dispatched, top_k] + expert_mask = expanded_topk == expert_ids.unsqueeze( + 1 + ) # [num_dispatched, top_k] + + # Get the k index for each dispatched pair + k_indices = expert_mask.int().argmax(dim=1) # [num_dispatched] + + # Compute flat indices into intermediate_cache viewed as [M * top_k, dim] + flat_indices = token_ids * top_k + k_indices # [num_dispatched] + + # Reshape cache for scatter_add + cache_flat = intermediate_cache.view(M * top_k, dim) + + # Scatter add + cache_flat.scatter_add_( + 0, + flat_indices.unsqueeze(-1).expand(-1, dim), + lora_delta.to(cache_flat.dtype), + ) + + def _gather_dispatched_inputs( + self, + intermediate_input: torch.Tensor, # [M * top_k, dim] + token_ids: torch.Tensor, # [num_dispatched] + expert_ids: torch.Tensor, # [num_dispatched] + topk_ids: torch.Tensor, # [M, top_k] + M: int, + top_k: int, + ) -> torch.Tensor: + """ + Gather intermediate inputs for dispatched (token, expert) pairs. + + Returns tensor of shape [num_dispatched, dim]. 
+ """ + # Find which top_k slot each dispatched pair corresponds to + expanded_topk = topk_ids[token_ids] # [num_dispatched, top_k] + expert_mask = expanded_topk == expert_ids.unsqueeze(1) + k_indices = expert_mask.int().argmax(dim=1) # [num_dispatched] + + # Compute flat indices + flat_indices = token_ids * top_k + k_indices + + # Gather + return intermediate_input[flat_indices] + + def _gather_router_weights( + self, + topk_weights: torch.Tensor, # [M, top_k] + token_ids: torch.Tensor, # [num_dispatched] + expert_ids: torch.Tensor, # [num_dispatched] + topk_ids: torch.Tensor, # [M, top_k] + ) -> torch.Tensor: + """ + Gather router weights for dispatched (token, expert) pairs. + + Returns tensor of shape [num_dispatched]. + """ + # Find which top_k slot each dispatched pair corresponds to + expanded_topk = topk_ids[token_ids] # [num_dispatched, top_k] + expert_mask = expanded_topk == expert_ids.unsqueeze(1) + k_indices = expert_mask.int().argmax(dim=1) # [num_dispatched] + + # Gather weights + expanded_weights = topk_weights[token_ids] # [num_dispatched, top_k] + return expanded_weights[ + torch.arange(len(token_ids), device=topk_weights.device), k_indices + ] From 099aa82ac16634a7c6845a252732896f97148e82 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 28 Jan 2026 14:27:58 -0500 Subject: [PATCH 057/150] refactor to utilize vLLM kernel --- python/sglang/srt/lora/layers.py | 13 +- python/sglang/srt/lora/lora.py | 1 - python/sglang/srt/lora/lora_manager.py | 11 +- python/sglang/srt/lora/lora_moe_runners.py | 294 +++----- python/sglang/srt/lora/triton_ops/__init__.py | 2 + .../lora/triton_ops/fused_moe_lora_kernel.py | 677 ++++++++++++++++++ 6 files changed, 808 insertions(+), 190 deletions(-) create mode 100644 python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index f26b240f9464..56ab959a7f0f 100644 --- a/python/sglang/srt/lora/layers.py +++ 
b/python/sglang/srt/lora/layers.py @@ -594,8 +594,10 @@ def __init__( self, base_layer: nn.Module, lora_backend: BaseLoRABackend, + adapter_enabled: torch.Tensor = None, ): super().__init__(base_layer, lora_backend) + self.adapter_enabled = adapter_enabled # LoRA tensors will be set by LoRAManager self.gate_up_lora_a_weights = None self.gate_up_lora_b_weights = None @@ -641,6 +643,9 @@ def _get_lora_info( lora_ranks = batch_info.lora_ranks # [num_loras] scalings = batch_info.scalings # [num_loras] + # Get adapter_enabled from layer instance, slice to match current batch + adapter_enabled = self.adapter_enabled[:len(lora_ranks)] if self.adapter_enabled is not None else torch.zeros(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) + # Use precomputed per-token LoRA indices from forward batch lora_indices = self.lora_backend.forward_batch.token_lora_indices @@ -661,6 +666,7 @@ def _get_lora_info( lora_ids=lora_ids, lora_ranks=lora_ranks, lora_scalings=scalings, + adapter_enabled=adapter_enabled, num_experts=self.base_layer.num_experts, ) @@ -763,7 +769,7 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): def get_lora_layer( - layer: nn.Module, lora_backend: BaseLoRABackend + layer: nn.Module, lora_backend: BaseLoRABackend, adapter_enabled: torch.Tensor = None ) -> BaseLayerWithLoRA: # FusedMoE is now imported at the top of the file # FusedMoEWithLoRA is now defined in this file @@ -780,6 +786,9 @@ def get_lora_layer( } for src_layer_type, lora_layer_type in supported_layer_types.items(): if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck - ret = lora_layer_type(layer, lora_backend) + if src_layer_type == FusedMoE: + ret = lora_layer_type(layer, lora_backend, adapter_enabled) + else: + ret = lora_layer_type(layer, lora_backend) return ret raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 
12c813baeb20..005a21a41516 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -46,7 +46,6 @@ def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig): class LoRAAdapter(nn.Module): - def __init__( self, uid: str, diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index e04e084912a8..e243e7e0b87a 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -75,6 +75,9 @@ def __init__( self.tp_rank: int = tp_rank self.lora_added_tokens_size: Optional[int] = None + # Track which LoRA adapters are enabled (loaded/available) + self.adapter_enabled = torch.zeros(max_loras_per_batch + 1, dtype=torch.int32, device=self.device) + # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -141,6 +144,9 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: # keep metadata for displayed messages self.lora_refs[lora_ref.lora_id] = lora_ref self.num_pinned_loras += int(lora_ref.pinned) + + # Mark this adapter as enabled + self.adapter_enabled[lora_ref.lora_id] = 1 except Exception as e: return self.create_lora_update_result( success=False, @@ -202,6 +208,9 @@ def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: del self.loras[lora_ref.lora_id] del self.lora_refs[lora_ref.lora_id] self.num_pinned_loras -= int(lora_ref.pinned) + + # Mark this adapter as disabled + self.adapter_enabled[lora_ref.lora_id] = 0 except Exception as e: return self.create_lora_update_result( success=False, @@ -530,7 +539,7 @@ def init_memory_pool(self): def set_lora_module(self, module_name, module): """Wrap any module (standard or MoE) with LoRA support.""" - lora_module = get_lora_layer(module, self.lora_backend) + lora_module = get_lora_layer(module, self.lora_backend, self.adapter_enabled) replace_submodule(self.base_model, module_name, lora_module) return lora_module diff --git 
a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index d81034d9d88a..532cc011d014 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -69,6 +69,7 @@ class LoRAInfo: # LoRA config per adapter lora_ranks: torch.Tensor # [num_loras] lora_scalings: torch.Tensor # [num_loras] + adapter_enabled: torch.Tensor # [num_loras] - which adapters are enabled num_experts: int @@ -212,6 +213,10 @@ def run( dtype=hidden_states.dtype, ) + # output shape: [M, top_k, N] and in original token order since we do not pass in c_sorted=True. If we + # want to get the output in sorted order by expert, we can pass in c_sorted=True. + # TODO: determine whether we should pass in c_sorted=True. That will make it less readable and different from base run method. + # but we won't have to scatter the lora delta to the correct positions before adding them to this base output. invoke_fused_moe_kernel( hidden_states, w13, @@ -244,6 +249,7 @@ def run( hidden_states=hidden_states, intermediate_cache=intermediate_cache1, topk_ids=topk_ids, + topk_weights=topk_weights, lora_info=lora_info, ) @@ -299,6 +305,12 @@ def run( else: out_hidden_states = torch.empty_like(hidden_states) + + # output shape: [M, hidden_dim] and in original token order since we do not pass in c_sorted=True. If we + # want to get the output in sorted order by expert, we can pass in c_sorted=True. We use no_combine=False because the next + # LoRA computation requires input to be [M, hidden_dim] + # TODO: determine whether we should pass in c_sorted=True. That will make it less readable and different from base run method. + # but we won't have to scatter the lora delta to the correct positions before adding them to this base output. 
invoke_fused_moe_kernel( intermediate_cache2, w2, @@ -328,9 +340,10 @@ def run( block_shape=block_shape, ) - # ============================================================ + # ============================================================ # Stage 3.5: Add LoRA down delta BEFORE final reduction # ============================================================ + # intermediate_cache2 is in the original token order and token-major order. self._add_lora_down_delta( intermediate_input=intermediate_cache2, intermediate_cache=intermediate_cache3, @@ -340,6 +353,9 @@ def run( lora_info=lora_info, ) + # we still need to combine the output + + # ============================================================ # Stage 4: Final reduction (sum across top_k) # ============================================================ @@ -394,6 +410,7 @@ def _add_lora_gate_up_delta( hidden_states: torch.Tensor, # [M, hidden_dim] intermediate_cache: torch.Tensor, # [M, top_k, gate_up_dim] topk_ids: torch.Tensor, # [M, top_k] + topk_weights: torch.Tensor, # [M, top_k] lora_info: LoRAInfo, ) -> None: """ @@ -403,44 +420,56 @@ def _add_lora_gate_up_delta( delta = scaling * B @ (A @ hidden_states[token]) and adds it to intermediate_cache[token, k] where k is the top_k index. 
""" - from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( - per_expert_lora_forward, - ) + from sglang.srt.lora.triton_ops import fused_moe_lora M, top_k, gate_up_dim = intermediate_cache.shape num_dispatched = lora_info.token_ids.shape[0] - # Compute LoRA delta for each (token, expert) pair - # Output shape: [num_dispatched, gate_up_dim] - lora_delta = torch.zeros( - (num_dispatched, gate_up_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - _, lora_delta = per_expert_lora_forward( - hidden_states=hidden_states, - lora_a_weights=lora_info.gate_up_lora_a_weights, - lora_b_weights=lora_info.gate_up_lora_b_weights, - token_ids=lora_info.token_ids, - expert_ids=lora_info.expert_ids, + # Compute LoRA delta where intermediate_cache needs to be [M, top_k, gate_up_dim] in original token order. + # Output shape: [M, top_k, gate_up_dim] + # Hidden_states shape: [M, hidden_dim] (handles token duplication internally) + + num_tokens_post_padded_formatted = torch.tensor([num_dispatched], dtype=torch.int32, device=hidden_states.device) + actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) + + # Handle multi-LoRA: stack weights for all loaded LoRAs + # lora_info.gate_up_lora_a_weights shape: [num_loras, num_experts, max_rank, hidden_dim] + # Note: LoRA scaling factors (lora_info.lora_scalings) are already applied to weights during loading + max_loras = len(lora_info.lora_ranks) + + lora_a_stacked = [lora_info.gate_up_lora_a_weights[i] for i in range(max_loras)] + lora_b_stacked = [lora_info.gate_up_lora_b_weights[i] for i in range(max_loras)] + + fused_moe_lora( + output=intermediate_cache, + qcurr_hidden_states=hidden_states, + lora_a_stacked=lora_a_stacked, + lora_b_stacked=lora_b_stacked, + topk_weights=topk_weights, # Use actual routing weights + sorted_token_ids=lora_info.token_ids.unsqueeze(0), + expert_ids=lora_info.expert_ids.unsqueeze(0), + num_tokens_post_padded=num_tokens_post_padded_formatted, + 
max_lora_rank=actual_max_lora_rank, + top_k_num=top_k, lora_ids=lora_info.lora_ids, - lora_ranks=lora_info.lora_ranks, - lora_scalings=lora_info.lora_scalings, - num_experts=lora_info.num_experts, - base_output=lora_delta, - is_down_proj=False, + adapter_enabled=lora_info.adapter_enabled, + shrink_block_size_m=64, + shrink_block_size_n=64, + shrink_block_size_k=64, + shrink_group_size_m=8, + shrink_num_warps=4, + shrink_num_stages=2, + shrink_split_k=1, + expand_block_size_m=64, + expand_block_size_n=64, + expand_block_size_k=64, + expand_group_size_m=8, + expand_num_warps=4, + expand_num_stages=2, + expand_split_k=1, ) - # Add delta to intermediate_cache at the right positions - # We need to map from dispatched indices back to (token, top_k_idx) pairs - self._scatter_add_to_topk_cache( - lora_delta=lora_delta, - intermediate_cache=intermediate_cache, - token_ids=lora_info.token_ids, - expert_ids=lora_info.expert_ids, - topk_ids=topk_ids, - ) + def _add_lora_down_delta( self, @@ -458,163 +487,56 @@ def _add_lora_down_delta( delta = scaling * B @ (A @ intermediate_input[dispatched_idx]) and adds it to intermediate_cache[token, k]. 
""" - from sglang.srt.lora.triton_ops.per_expert_lora_moe import ( - per_expert_lora_forward, - ) + from sglang.srt.lora.triton_ops import fused_moe_lora M, top_k, hidden_dim = intermediate_cache.shape - # Build indices to gather from intermediate_input - # For each dispatched (token, expert) pair, find which top_k slot it corresponds to - lora_intermediate_input = self._gather_dispatched_inputs( - intermediate_input=intermediate_input, - token_ids=lora_info.token_ids, - expert_ids=lora_info.expert_ids, - topk_ids=topk_ids, - M=M, - top_k=top_k, - ) - - # Compute LoRA delta for down projection - num_dispatched = lora_info.token_ids.shape[0] - lora_delta = torch.zeros( - (num_dispatched, hidden_dim), - dtype=intermediate_input.dtype, - device=intermediate_input.device, - ) - - # IMPORTANT: For down_proj, the input (lora_intermediate_input) is already - # gathered and indexed by dispatched position (0, 1, ..., num_dispatched-1), - # not by original token position. So we pass identity indices for token_ids - # to make the kernel read from the correct positions. - dispatched_indices = torch.arange( - num_dispatched, - device=lora_info.token_ids.device, - dtype=lora_info.token_ids.dtype, - ) - - _, lora_delta = per_expert_lora_forward( - hidden_states=lora_intermediate_input, - lora_a_weights=lora_info.down_lora_a_weights, - lora_b_weights=lora_info.down_lora_b_weights, - token_ids=dispatched_indices, # Use identity indices, not original token_ids - expert_ids=lora_info.expert_ids, + # intermediate_input is the input from the previous stage and is in original token order. 
+ + # Data format adaptation for vLLM kernel + num_dispatched_down = lora_info.token_ids.shape[0] + num_tokens_post_padded_formatted = torch.tensor([num_dispatched_down], dtype=torch.int32, device=intermediate_input.device) + actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) + + # Handle multi-LoRA: stack weights for all loaded LoRAs + # lora_info.down_lora_a_weights shape: [num_loras, num_experts, max_rank, intermediate_dim] + # Note: LoRA scaling factors (lora_info.lora_scalings) are already applied to weights during loading + max_loras = len(lora_info.lora_ranks) + + # Validate weight dimensions match expectations + assert lora_info.down_lora_a_weights.shape[0] == max_loras, f"Expected {max_loras} LoRAs, got {lora_info.down_lora_a_weights.shape[0]}" + assert lora_info.adapter_enabled.shape[0] >= max_loras, f"adapter_enabled too small: {lora_info.adapter_enabled.shape[0]} < {max_loras}" + + lora_a_stacked = [lora_info.down_lora_a_weights[i] for i in range(max_loras)] + lora_b_stacked = [lora_info.down_lora_b_weights[i] for i in range(max_loras)] + + fused_moe_lora( + output=intermediate_cache, + qcurr_hidden_states=intermediate_input, + lora_a_stacked=lora_a_stacked, + lora_b_stacked=lora_b_stacked, + topk_weights=topk_weights, # Use the routing weights passed to this function + sorted_token_ids=lora_info.token_ids.unsqueeze(0), + expert_ids=lora_info.expert_ids.unsqueeze(0), + num_tokens_post_padded=num_tokens_post_padded_formatted, + max_lora_rank=actual_max_lora_rank, + top_k_num=top_k, lora_ids=lora_info.lora_ids, - lora_ranks=lora_info.lora_ranks, - lora_scalings=lora_info.lora_scalings, - num_experts=lora_info.num_experts, - base_output=lora_delta, - is_down_proj=True, - ) - - # Apply router weights if not already applied to input - # This matches the base MoE behavior - if not apply_router_weight_on_input: - # Get router weights for each dispatched pair - router_weights = self._gather_router_weights( - topk_weights=topk_weights, - 
token_ids=lora_info.token_ids, - expert_ids=lora_info.expert_ids, - topk_ids=topk_ids, - ) - lora_delta = lora_delta * router_weights.unsqueeze(-1) - - # Add delta to intermediate_cache - self._scatter_add_to_topk_cache( - lora_delta=lora_delta, - intermediate_cache=intermediate_cache, - token_ids=lora_info.token_ids, - expert_ids=lora_info.expert_ids, - topk_ids=topk_ids, + adapter_enabled=lora_info.adapter_enabled, + shrink_block_size_m=64, + shrink_block_size_n=64, + shrink_block_size_k=64, + shrink_group_size_m=8, + shrink_num_warps=4, + shrink_num_stages=2, + shrink_split_k=1, + expand_block_size_m=64, + expand_block_size_n=64, + expand_block_size_k=64, + expand_group_size_m=8, + expand_num_warps=4, + expand_num_stages=2, + expand_split_k=1, ) + - def _scatter_add_to_topk_cache( - self, - lora_delta: torch.Tensor, # [num_dispatched, dim] - intermediate_cache: torch.Tensor, # [M, top_k, dim] - token_ids: torch.Tensor, # [num_dispatched] - expert_ids: torch.Tensor, # [num_dispatched] - topk_ids: torch.Tensor, # [M, top_k] - ) -> None: - """ - Scatter-add lora_delta to intermediate_cache based on dispatch info. 
- - For each dispatched index d: - - token_id = token_ids[d] - - expert_id = expert_ids[d] - - Find k such that topk_ids[token_id, k] == expert_id - - intermediate_cache[token_id, k] += lora_delta[d] - """ - M, top_k, dim = intermediate_cache.shape - - # Find the top_k index for each dispatched pair - # topk_ids[token_ids] gives [num_dispatched, top_k] - # We need to find which column matches expert_ids - expanded_topk = topk_ids[token_ids] # [num_dispatched, top_k] - expert_mask = expanded_topk == expert_ids.unsqueeze( - 1 - ) # [num_dispatched, top_k] - - # Get the k index for each dispatched pair - k_indices = expert_mask.int().argmax(dim=1) # [num_dispatched] - - # Compute flat indices into intermediate_cache viewed as [M * top_k, dim] - flat_indices = token_ids * top_k + k_indices # [num_dispatched] - - # Reshape cache for scatter_add - cache_flat = intermediate_cache.view(M * top_k, dim) - - # Scatter add - cache_flat.scatter_add_( - 0, - flat_indices.unsqueeze(-1).expand(-1, dim), - lora_delta.to(cache_flat.dtype), - ) - - def _gather_dispatched_inputs( - self, - intermediate_input: torch.Tensor, # [M * top_k, dim] - token_ids: torch.Tensor, # [num_dispatched] - expert_ids: torch.Tensor, # [num_dispatched] - topk_ids: torch.Tensor, # [M, top_k] - M: int, - top_k: int, - ) -> torch.Tensor: - """ - Gather intermediate inputs for dispatched (token, expert) pairs. - - Returns tensor of shape [num_dispatched, dim]. 
- """ - # Find which top_k slot each dispatched pair corresponds to - expanded_topk = topk_ids[token_ids] # [num_dispatched, top_k] - expert_mask = expanded_topk == expert_ids.unsqueeze(1) - k_indices = expert_mask.int().argmax(dim=1) # [num_dispatched] - - # Compute flat indices - flat_indices = token_ids * top_k + k_indices - - # Gather - return intermediate_input[flat_indices] - - def _gather_router_weights( - self, - topk_weights: torch.Tensor, # [M, top_k] - token_ids: torch.Tensor, # [num_dispatched] - expert_ids: torch.Tensor, # [num_dispatched] - topk_ids: torch.Tensor, # [M, top_k] - ) -> torch.Tensor: - """ - Gather router weights for dispatched (token, expert) pairs. - - Returns tensor of shape [num_dispatched]. - """ - # Find which top_k slot each dispatched pair corresponds to - expanded_topk = topk_ids[token_ids] # [num_dispatched, top_k] - expert_mask = expanded_topk == expert_ids.unsqueeze(1) - k_indices = expert_mask.int().argmax(dim=1) # [num_dispatched] - - # Gather weights - expanded_weights = topk_weights[token_ids] # [num_dispatched, top_k] - return expanded_weights[ - torch.arange(len(token_ids), device=topk_weights.device), k_indices - ] diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index f61a76a86823..4976a072a220 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -1,6 +1,7 @@ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward from .embedding_lora_a import embedding_lora_a_fwd +from .fused_moe_lora_kernel import fused_moe_lora from .gate_up_lora_b import gate_up_lora_b_fwd from .per_expert_lora_moe import per_expert_lora_forward from .qkv_lora_b import qkv_lora_b_fwd @@ -15,5 +16,6 @@ "chunked_sgmv_lora_shrink_forward", "chunked_sgmv_lora_expand_forward", "per_expert_lora_forward", + "fused_moe_lora", "embedding_lora_a_fwd", ] diff --git 
a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py new file mode 100644 index 000000000000..e7d730c6bb70 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -0,0 +1,677 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/lora/ops/triton_ops/fused_moe_lora_op.py + +import torch +import triton +import triton.language as tl + +from sglang.srt.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) + +# Import SGLang's standard PDL support detection + +from sgl_kernel.utils import is_arch_support_pdl + + +_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {} + + +def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): + """ + `_LORA_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + """ + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + + if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None: + return ptr_tensor + + tensor_ptrs = [] + for lora_weight in lora_weights: + tensor_ptrs.append(lora_weight.data_ptr()) + ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) + + _LORA_PTR_DICT[key] = ptr_tensor + return _LORA_PTR_DICT.get(key) + + +@triton.jit( + do_not_specialize=[ + "num_valid_tokens", + "EM", + "stride_tl", + "stride_el", + "slice_a_size", + "slice_c_size", + ] +) +def _fused_moe_lora_kernel( + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + num_experts, + lora_ids, + adapter_enabled, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. 
`stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_bl, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_tl, + stride_el, + slice_a_size, + slice_c_size, + # Meta-parameters + num_slice_a: tl.constexpr, + num_slice_c: tl.constexpr, + top_k: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + USE_GDC: tl.constexpr, + launch_pdl: tl.constexpr, + IS_PRIMARY: tl.constexpr, +): + pid = tl.program_id(axis=0) + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + lora_id = tl.load(lora_ids + lora_idx) + + if lora_id == -1: + # Early exit for the no-lora case. + return + moe_enabled = tl.load(adapter_enabled + lora_id) + if moe_enabled == 0: + # Early exit for the no moe lora case. + return + max_loras = tl.num_programs(axis=2) + grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) + + # calculate pid_m,pid_n + pid_sk = pid % SPLIT_K + pid_m_n = pid // SPLIT_K + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid_m_n // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m) + pid_n = (pid_m_n % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + # get the expert_id to process curr shard + ind = lora_id * stride_el + pid_m + expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1) + if expert_id == -1: + return + # get a_ptr,b_ptr,c_ptr + cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size + cur_b_ptr = tl.load(b_ptr + 
slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) + cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + token_ind = stride_tl * lora_id + offs_token_id + offs_token = tl.load( + sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0 + ) + token_mask = offs_token < num_valid_tokens + + # get a_ptrs,b_ptrs + a_ptrs = cur_a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + b_ptrs = ( + cur_b_ptr + + lora_id * stride_bl + + expert_id * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if USE_GDC and IS_PRIMARY: + # GDC launch dependents hints the runtime system to launch dependent kernels. + tl.extra.cuda.gdc_launch_dependents() + + # accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # GDC wait waits for ALL programs in the prior kernel to complete + # before continuing. + if USE_GDC and not IS_PRIMARY: + tl.extra.cuda.gdc_wait() + + for k in range(0, grid_k): + k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) + # pre-fetch lora weight + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), + other=0.0, + ) + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + accumulator = accumulator.to(c_ptr.dtype.element_ty) + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed") + + +@torch.inference_mode() +def _fused_moe_lora_shrink( + a_intermediate_cache1: torch.Tensor, + # (num_slices, num_tokens, top_k_num, max_lora_rank) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: + w1_lora_a_stacked = lora_a_stacked[0] + use_gdc = is_arch_support_pdl() + shrink_config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "SPLIT_K": split_k, + "USE_GDC": use_gdc, + "launch_pdl": use_gdc, # triton kernel metadata + } + + b_ptr = 
_get_ptr(lora_a_stacked, device) + + grid = lambda META: ( + split_k + * triton.cdiv(EM, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + len(lora_a_stacked), + lora_a_stacked[0].shape[0], + ) + _fused_moe_lora_kernel[grid]( + qcurr_hidden_states, + b_ptr, + a_intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + EM, + num_tokens, + num_experts, + lora_ids, + adapter_enabled, + qcurr_hidden_states.stride(0), + qcurr_hidden_states.stride(1), + w1_lora_a_stacked.stride(0), + w1_lora_a_stacked.stride(1), + w1_lora_a_stacked.stride(3), + w1_lora_a_stacked.stride(2), + a_intermediate_cache1.stride(2), + a_intermediate_cache1.stride(3), + sorted_token_ids.stride(0), + expert_ids.stride(0), + slice_a_size=qcurr_hidden_states.numel(), + slice_c_size=a_intermediate_cache1.numel() // num_slices, + num_slice_a=1, + num_slice_c=num_slices, + top_k=1 if mul_routed_weight else top_k_num, + MUL_ROUTED_WEIGHT=False, + IS_PRIMARY=True, + **shrink_config, + ) + + +@torch.inference_mode() +def _fused_moe_lora_expand( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank) + b_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, output_dim_size) + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] 
+ topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, + offset: int = 0, +) -> None: + b_ptr = _get_ptr(lora_b_stacked, device) + K = max_lora_rank + N = w1_output_dim_size + + w1_lora_b_stacked = lora_b_stacked[0] + + a_intermediate_cache1 = a_intermediate_cache1.view( + -1, a_intermediate_cache1.shape[3] + ) + + use_gdc = is_arch_support_pdl() + expand_config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "SPLIT_K": split_k, # Set split_k = 1 for expand calls + "USE_GDC": use_gdc, + "launch_pdl": use_gdc, # triton kernel metadata + } + + grid = lambda META: ( + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + len(lora_b_stacked), + lora_b_stacked[0].shape[0], + ) + _fused_moe_lora_kernel[grid]( + a_intermediate_cache1, + b_ptr, + b_intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + EM, + num_tokens, + num_experts, + lora_ids, + adapter_enabled, + a_intermediate_cache1.stride(0), + a_intermediate_cache1.stride(1), + w1_lora_b_stacked.stride(0), + w1_lora_b_stacked.stride(1), + w1_lora_b_stacked.stride(3), + w1_lora_b_stacked.stride(2), + b_intermediate_cache1.stride(2), + b_intermediate_cache1.stride(3), + sorted_token_ids.stride(0), + 
expert_ids.stride(0), + slice_a_size=a_intermediate_cache1.numel() // num_slices, + slice_c_size=b_intermediate_cache1.numel() // num_slices, + num_slice_a=num_slices, + num_slice_c=num_slices, + top_k=1, + MUL_ROUTED_WEIGHT=mul_routed_weight, + IS_PRIMARY=False, + **expand_config, + ) + for i in range(num_slices): + output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i] + + +@torch.inference_mode() +def _fused_moe_lora( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, N, max_lora_rank,),...] + topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + max_lora_rank: int, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, + fully_sharded: bool = False, + offset: int = 0, +) -> None: + assert len(lora_a_stacked) == len(lora_b_stacked) > 0 + assert ( + sorted_token_ids.dim() + == expert_ids.dim() + == topk_weights.dim() + == qcurr_hidden_states.dim() + == 2 + ) + assert ( + sorted_token_ids.shape[0] + == expert_ids.shape[0] + == num_tokens_post_padded.shape[0] + ) + assert output.shape[0] == topk_weights.shape[0] + assert top_k_num == topk_weights.shape[1] + device = qcurr_hidden_states.device + num_slices = 
len(lora_a_stacked) + w1_lora_b_stacked = lora_b_stacked[0] + num_experts = lora_a_stacked[0].shape[1] + N = max_lora_rank + M = topk_weights.shape[0] + EM = sorted_token_ids.shape[1] + K = qcurr_hidden_states.shape[1] + num_tokens = M * top_k_num + w1_output_dim_size = w1_lora_b_stacked.shape[2] + + a_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, max_lora_rank), + dtype=output.dtype, + device=device, + ) + + b_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, w1_output_dim_size), + dtype=output.dtype, + device=device, + ) + + _fused_moe_lora_shrink( + a_intermediate_cache1, + qcurr_hidden_states, + lora_a_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + lora_ids, + adapter_enabled, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + shrink_block_size_m, + shrink_block_size_n, + shrink_block_size_k, + shrink_group_size_m, + shrink_num_warps, + shrink_num_stages, + shrink_split_k, + mul_routed_weight, + ) + + if fully_sharded: + if max_lora_rank == w1_lora_b_stacked.shape[-1]: + a_intermediate_cache1 = tensor_model_parallel_all_reduce( + a_intermediate_cache1 + ) + else: + a_intermediate_cache1 = tensor_model_parallel_all_gather( + a_intermediate_cache1 + ) + + # reset max_lora_rank to the full rank after allgather + max_lora_rank = a_intermediate_cache1.shape[-1] + + _fused_moe_lora_expand( + output, + a_intermediate_cache1, + b_intermediate_cache1, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + lora_ids, + adapter_enabled, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + max_lora_rank, + w1_output_dim_size, + expand_block_size_m, + expand_block_size_n, + expand_block_size_k, + expand_group_size_m, + expand_num_warps, + expand_num_stages, + expand_split_k, + mul_routed_weight, + offset, + ) + + +def _fused_moe_lora_fake( + 
output: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def _fused_moe_lora_shrink_fake( + a_intermediate_cache1: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def _fused_moe_lora_expand_fake( + output: torch.Tensor, + a_intermediate_cache1: torch.Tensor, + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + 
block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +# Register as SGLang custom ops following the same pattern as other ops +try: + from sglang.srt.utils.common import direct_register_custom_op + + direct_register_custom_op( + op_name="fused_moe_lora", + op_func=_fused_moe_lora, + mutates_args=["output"], + fake_impl=_fused_moe_lora_fake, + ) + + direct_register_custom_op( + op_name="fused_moe_lora_shrink", + op_func=_fused_moe_lora_shrink, + mutates_args=["a_intermediate_cache1"], + fake_impl=_fused_moe_lora_shrink_fake, + ) + + direct_register_custom_op( + op_name="fused_moe_lora_expand", + op_func=_fused_moe_lora_expand, + mutates_args=["output"], + fake_impl=_fused_moe_lora_expand_fake, + ) + + # Export through torch.ops.sglang namespace + fused_moe_lora = torch.ops.sglang.fused_moe_lora + fused_moe_lora_shrink = torch.ops.sglang.fused_moe_lora_shrink + fused_moe_lora_expand = torch.ops.sglang.fused_moe_lora_expand + +except AttributeError: + fused_moe_lora = _fused_moe_lora + fused_moe_lora_shrink = _fused_moe_lora_shrink + fused_moe_lora_expand = _fused_moe_lora_expand From ac4a0082baba0cdd93b21512926a55067c619d03 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 28 Jan 2026 15:09:04 -0500 Subject: [PATCH 058/150] convert strings to int where necessary --- python/sglang/srt/lora/lora_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index e243e7e0b87a..611b2d33cffd 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -76,7 +76,7 @@ def __init__( self.lora_added_tokens_size: Optional[int] = None # Track which LoRA adapters are enabled (loaded/available) - self.adapter_enabled = torch.zeros(max_loras_per_batch + 1, dtype=torch.int32, device=self.device) 
+git self.adapter_enabled = torch.zeros(int(self.max_loras_per_batch) + 1, dtype=torch.int32, device=self.device) # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -146,7 +146,7 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: self.num_pinned_loras += int(lora_ref.pinned) # Mark this adapter as enabled - self.adapter_enabled[lora_ref.lora_id] = 1 + self.adapter_enabled[int(lora_ref.lora_id)] = 1 except Exception as e: return self.create_lora_update_result( success=False, @@ -210,7 +210,7 @@ def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: self.num_pinned_loras -= int(lora_ref.pinned) # Mark this adapter as disabled - self.adapter_enabled[lora_ref.lora_id] = 0 + self.adapter_enabled[int(lora_ref.lora_id)] = 0 except Exception as e: return self.create_lora_update_result( success=False, From 305acc9ff91a9c43d9c3646a81067fc12efe4509 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 28 Jan 2026 15:11:37 -0500 Subject: [PATCH 059/150] fix --- python/sglang/srt/lora/lora_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 611b2d33cffd..2969bba120f8 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -76,7 +76,9 @@ def __init__( self.lora_added_tokens_size: Optional[int] = None # Track which LoRA adapters are enabled (loaded/available) -git self.adapter_enabled = torch.zeros(int(self.max_loras_per_batch) + 1, dtype=torch.int32, device=self.device) + self.adapter_enabled = torch.zeros( + int(self.max_loras_per_batch) + 1, dtype=torch.int32, device=self.device + ) # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy From bf831a99e13c31fe0903e780647f94936aadf3ba Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 28 Jan 2026 15:18:47 -0500 Subject: [PATCH 060/150] fix --- 
python/sglang/srt/lora/layers.py | 14 +++++--------- python/sglang/srt/lora/lora_manager.py | 12 +----------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 56ab959a7f0f..98348a16044b 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -594,10 +594,8 @@ def __init__( self, base_layer: nn.Module, lora_backend: BaseLoRABackend, - adapter_enabled: torch.Tensor = None, ): super().__init__(base_layer, lora_backend) - self.adapter_enabled = adapter_enabled # LoRA tensors will be set by LoRAManager self.gate_up_lora_a_weights = None self.gate_up_lora_b_weights = None @@ -643,8 +641,9 @@ def _get_lora_info( lora_ranks = batch_info.lora_ranks # [num_loras] scalings = batch_info.scalings # [num_loras] - # Get adapter_enabled from layer instance, slice to match current batch - adapter_enabled = self.adapter_enabled[:len(lora_ranks)] if self.adapter_enabled is not None else torch.zeros(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) + # Create adapter_enabled tensor for the current batch + # All LoRAs in the batch are enabled by definition + adapter_enabled = torch.ones(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) # Use precomputed per-token LoRA indices from forward batch lora_indices = self.lora_backend.forward_batch.token_lora_indices @@ -769,7 +768,7 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): def get_lora_layer( - layer: nn.Module, lora_backend: BaseLoRABackend, adapter_enabled: torch.Tensor = None + layer: nn.Module, lora_backend: BaseLoRABackend ) -> BaseLayerWithLoRA: # FusedMoE is now imported at the top of the file # FusedMoEWithLoRA is now defined in this file @@ -786,9 +785,6 @@ def get_lora_layer( } for src_layer_type, lora_layer_type in supported_layer_types.items(): if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck - if src_layer_type == FusedMoE: - ret = 
lora_layer_type(layer, lora_backend, adapter_enabled) - else: - ret = lora_layer_type(layer, lora_backend) + ret = lora_layer_type(layer, lora_backend) return ret raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 2969bba120f8..93c07740c493 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -75,10 +75,6 @@ def __init__( self.tp_rank: int = tp_rank self.lora_added_tokens_size: Optional[int] = None - # Track which LoRA adapters are enabled (loaded/available) - self.adapter_enabled = torch.zeros( - int(self.max_loras_per_batch) + 1, dtype=torch.int32, device=self.device - ) # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -146,9 +142,6 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: # keep metadata for displayed messages self.lora_refs[lora_ref.lora_id] = lora_ref self.num_pinned_loras += int(lora_ref.pinned) - - # Mark this adapter as enabled - self.adapter_enabled[int(lora_ref.lora_id)] = 1 except Exception as e: return self.create_lora_update_result( success=False, @@ -210,9 +203,6 @@ def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput: del self.loras[lora_ref.lora_id] del self.lora_refs[lora_ref.lora_id] self.num_pinned_loras -= int(lora_ref.pinned) - - # Mark this adapter as disabled - self.adapter_enabled[int(lora_ref.lora_id)] = 0 except Exception as e: return self.create_lora_update_result( success=False, @@ -541,7 +531,7 @@ def init_memory_pool(self): def set_lora_module(self, module_name, module): """Wrap any module (standard or MoE) with LoRA support.""" - lora_module = get_lora_layer(module, self.lora_backend, self.adapter_enabled) + lora_module = get_lora_layer(module, self.lora_backend) replace_submodule(self.base_model, module_name, lora_module) return lora_module From 
7c5880a8cca934a098b4aa1a0ceb22d2a709fb44 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 28 Jan 2026 15:26:57 -0500 Subject: [PATCH 061/150] fix --- python/sglang/srt/lora/lora_moe_runners.py | 37 ++++++++++------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 532cc011d014..29238f6d1370 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -305,7 +305,6 @@ def run( else: out_hidden_states = torch.empty_like(hidden_states) - # output shape: [M, hidden_dim] and in original token order since we do not pass in c_sorted=True. If we # want to get the output in sorted order by expert, we can pass in c_sorted=True. We use no_combine=False because the next # LoRA computation requires input to be [M, hidden_dim] @@ -340,7 +339,7 @@ def run( block_shape=block_shape, ) - # ============================================================ + # ============================================================ # Stage 3.5: Add LoRA down delta BEFORE final reduction # ============================================================ # intermediate_cache2 is in the original token order and token-major order. 
@@ -355,7 +354,6 @@ def run( # we still need to combine the output - # ============================================================ # Stage 4: Final reduction (sum across top_k) # ============================================================ @@ -429,16 +427,13 @@ def _add_lora_gate_up_delta( # Output shape: [M, top_k, gate_up_dim] # Hidden_states shape: [M, hidden_dim] (handles token duplication internally) - num_tokens_post_padded_formatted = torch.tensor([num_dispatched], dtype=torch.int32, device=hidden_states.device) + num_tokens_post_padded_formatted = torch.tensor( + [num_dispatched], dtype=torch.int32, device=hidden_states.device + ) actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) - # Handle multi-LoRA: stack weights for all loaded LoRAs - # lora_info.gate_up_lora_a_weights shape: [num_loras, num_experts, max_rank, hidden_dim] - # Note: LoRA scaling factors (lora_info.lora_scalings) are already applied to weights during loading - max_loras = len(lora_info.lora_ranks) - - lora_a_stacked = [lora_info.gate_up_lora_a_weights[i] for i in range(max_loras)] - lora_b_stacked = [lora_info.gate_up_lora_b_weights[i] for i in range(max_loras)] + lora_a_stacked = [lora_info.gate_up_lora_a_weights] + lora_b_stacked = [lora_info.gate_up_lora_b_weights] fused_moe_lora( output=intermediate_cache, @@ -469,8 +464,6 @@ def _add_lora_gate_up_delta( expand_split_k=1, ) - - def _add_lora_down_delta( self, intermediate_input: torch.Tensor, # [M * top_k, intermediate_dim] @@ -495,7 +488,9 @@ def _add_lora_down_delta( # Data format adaptation for vLLM kernel num_dispatched_down = lora_info.token_ids.shape[0] - num_tokens_post_padded_formatted = torch.tensor([num_dispatched_down], dtype=torch.int32, device=intermediate_input.device) + num_tokens_post_padded_formatted = torch.tensor( + [num_dispatched_down], dtype=torch.int32, device=intermediate_input.device + ) actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) # Handle multi-LoRA: stack weights for all loaded 
LoRAs @@ -504,11 +499,15 @@ def _add_lora_down_delta( max_loras = len(lora_info.lora_ranks) # Validate weight dimensions match expectations - assert lora_info.down_lora_a_weights.shape[0] == max_loras, f"Expected {max_loras} LoRAs, got {lora_info.down_lora_a_weights.shape[0]}" - assert lora_info.adapter_enabled.shape[0] >= max_loras, f"adapter_enabled too small: {lora_info.adapter_enabled.shape[0]} < {max_loras}" + assert ( + lora_info.down_lora_a_weights.shape[0] == max_loras + ), f"Expected {max_loras} LoRAs, got {lora_info.down_lora_a_weights.shape[0]}" + assert ( + lora_info.adapter_enabled.shape[0] >= max_loras + ), f"adapter_enabled too small: {lora_info.adapter_enabled.shape[0]} < {max_loras}" - lora_a_stacked = [lora_info.down_lora_a_weights[i] for i in range(max_loras)] - lora_b_stacked = [lora_info.down_lora_b_weights[i] for i in range(max_loras)] + lora_a_stacked = [lora_info.down_lora_a_weights] + lora_b_stacked = [lora_info.down_lora_b_weights] fused_moe_lora( output=intermediate_cache, @@ -538,5 +537,3 @@ def _add_lora_down_delta( expand_num_stages=2, expand_split_k=1, ) - - From fec49f1aa398df7f0478288df6106c643401b90d Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 28 Jan 2026 15:30:33 -0500 Subject: [PATCH 062/150] fix --- python/sglang/srt/lora/lora_moe_runners.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 29238f6d1370..797863efa9c9 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -432,6 +432,10 @@ def _add_lora_gate_up_delta( ) actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) + # Skip LoRA computation if no LoRA adapters have non-zero rank + if actual_max_lora_rank == 0: + return + lora_a_stacked = [lora_info.gate_up_lora_a_weights] lora_b_stacked = [lora_info.gate_up_lora_b_weights] @@ -493,6 +497,10 @@ def _add_lora_down_delta( ) actual_max_lora_rank = 
int(lora_info.lora_ranks.max().item()) + # Skip LoRA computation if no LoRA adapters have non-zero rank + if actual_max_lora_rank == 0: + return + # Handle multi-LoRA: stack weights for all loaded LoRAs # lora_info.down_lora_a_weights shape: [num_loras, num_experts, max_rank, intermediate_dim] # Note: LoRA scaling factors (lora_info.lora_scalings) are already applied to weights during loading From 307abef823c2df222582bf0448db4bb95fd967cb Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 28 Jan 2026 15:34:44 -0500 Subject: [PATCH 063/150] fix --- python/sglang/srt/lora/lora_moe_runners.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 797863efa9c9..94fac553af32 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -251,6 +251,7 @@ def run( topk_ids=topk_ids, topk_weights=topk_weights, lora_info=lora_info, + num_tokens_post_padded=num_tokens_post_padded, ) # ============================================================ @@ -350,6 +351,7 @@ def run( topk_weights=topk_weights, apply_router_weight_on_input=apply_router_weight_on_input, lora_info=lora_info, + num_tokens_post_padded=num_tokens_post_padded, ) # we still need to combine the output @@ -410,6 +412,7 @@ def _add_lora_gate_up_delta( topk_ids: torch.Tensor, # [M, top_k] topk_weights: torch.Tensor, # [M, top_k] lora_info: LoRAInfo, + num_tokens_post_padded: torch.Tensor, ) -> None: """ Add LoRA gate_up delta to intermediate_cache in-place. 
@@ -427,9 +430,10 @@ def _add_lora_gate_up_delta( # Output shape: [M, top_k, gate_up_dim] # Hidden_states shape: [M, hidden_dim] (handles token duplication internally) - num_tokens_post_padded_formatted = torch.tensor( - [num_dispatched], dtype=torch.int32, device=hidden_states.device - ) + # Create num_tokens_post_padded tensor for vLLM kernel + # It expects shape (max_loras,) with the same value for each LoRA + num_loras = len(lora_info.lora_ranks) + num_tokens_post_padded_formatted = num_tokens_post_padded.expand(num_loras) actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) # Skip LoRA computation if no LoRA adapters have non-zero rank @@ -476,6 +480,7 @@ def _add_lora_down_delta( topk_weights: torch.Tensor, # [M, top_k] apply_router_weight_on_input: bool, lora_info: LoRAInfo, + num_tokens_post_padded: torch.Tensor, ) -> None: """ Add LoRA down delta to intermediate_cache in-place. @@ -492,9 +497,10 @@ def _add_lora_down_delta( # Data format adaptation for vLLM kernel num_dispatched_down = lora_info.token_ids.shape[0] - num_tokens_post_padded_formatted = torch.tensor( - [num_dispatched_down], dtype=torch.int32, device=intermediate_input.device - ) + # Create num_tokens_post_padded tensor for vLLM kernel + # It expects shape (max_loras,) with the same value for each LoRA + num_loras = len(lora_info.lora_ranks) + num_tokens_post_padded_formatted = num_tokens_post_padded.expand(num_loras) actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) # Skip LoRA computation if no LoRA adapters have non-zero rank From c9062b0a30bccf9684661d92166b61450f19746b Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 28 Jan 2026 15:45:08 -0500 Subject: [PATCH 064/150] fix --- python/sglang/srt/lora/layers.py | 5 +++++ python/sglang/srt/lora/lora_manager.py | 1 + python/sglang/srt/lora/lora_moe_runners.py | 22 ++++++++++++---------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/lora/layers.py 
b/python/sglang/srt/lora/layers.py index 98348a16044b..4b5b73e773f7 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -602,6 +602,7 @@ def __init__( self.down_lora_a_weights = None self.down_lora_b_weights = None self._lora_runner = None + self.max_lora_rank = 0 # Will be set by LoRAManager def set_lora_info( self, @@ -641,6 +642,9 @@ def _get_lora_info( lora_ranks = batch_info.lora_ranks # [num_loras] scalings = batch_info.scalings # [num_loras] + # Use global max LoRA rank set by LoRAManager + max_lora_rank = self.max_lora_rank + # Create adapter_enabled tensor for the current batch # All LoRAs in the batch are enabled by definition adapter_enabled = torch.ones(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) @@ -666,6 +670,7 @@ def _get_lora_info( lora_ranks=lora_ranks, lora_scalings=scalings, adapter_enabled=adapter_enabled, + max_lora_rank=max_lora_rank, num_experts=self.base_layer.num_experts, ) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 93c07740c493..649c2b1aa52c 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -354,6 +354,7 @@ def update_lora_info(self): down_lora_a_weights=down_a, down_lora_b_weights=down_b, ) + module.max_lora_rank = self.max_lora_rank continue target_module = get_target_module_name( diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 94fac553af32..5cd06993ec7c 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -70,6 +70,7 @@ class LoRAInfo: lora_ranks: torch.Tensor # [num_loras] lora_scalings: torch.Tensor # [num_loras] adapter_enabled: torch.Tensor # [num_loras] - which adapters are enabled + max_lora_rank: int # Maximum LoRA rank across all adapters num_experts: int @@ -430,15 +431,15 @@ def _add_lora_gate_up_delta( # Output shape: [M, top_k, gate_up_dim] # Hidden_states 
shape: [M, hidden_dim] (handles token duplication internally) + # Skip LoRA computation if no LoRA adapters have non-zero rank + if lora_info.max_lora_rank == 0: + return + # Create num_tokens_post_padded tensor for vLLM kernel # It expects shape (max_loras,) with the same value for each LoRA num_loras = len(lora_info.lora_ranks) num_tokens_post_padded_formatted = num_tokens_post_padded.expand(num_loras) - actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) - - # Skip LoRA computation if no LoRA adapters have non-zero rank - if actual_max_lora_rank == 0: - return + actual_max_lora_rank = lora_info.max_lora_rank lora_a_stacked = [lora_info.gate_up_lora_a_weights] lora_b_stacked = [lora_info.gate_up_lora_b_weights] @@ -497,15 +498,16 @@ def _add_lora_down_delta( # Data format adaptation for vLLM kernel num_dispatched_down = lora_info.token_ids.shape[0] + + # Skip LoRA computation if no LoRA adapters have non-zero rank + if lora_info.max_lora_rank == 0: + return + # Create num_tokens_post_padded tensor for vLLM kernel # It expects shape (max_loras,) with the same value for each LoRA num_loras = len(lora_info.lora_ranks) num_tokens_post_padded_formatted = num_tokens_post_padded.expand(num_loras) - actual_max_lora_rank = int(lora_info.lora_ranks.max().item()) - - # Skip LoRA computation if no LoRA adapters have non-zero rank - if actual_max_lora_rank == 0: - return + actual_max_lora_rank = lora_info.max_lora_rank # Handle multi-LoRA: stack weights for all loaded LoRAs # lora_info.down_lora_a_weights shape: [num_loras, num_experts, max_rank, intermediate_dim] From 6e10967e51a9198827515669ae013a2cde2656c5 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 31 Jan 2026 12:40:42 -0500 Subject: [PATCH 065/150] add unit tests --- .../manual/lora/test_fused_moe_lora_kernel.py | 324 ++++++++++++++++++ test/manual/lora/test_moe_lora_align_sum.py | 0 2 files changed, 324 insertions(+) create mode 100644 test/manual/lora/test_fused_moe_lora_kernel.py create mode 
100644 test/manual/lora/test_moe_lora_align_sum.py diff --git a/test/manual/lora/test_fused_moe_lora_kernel.py b/test/manual/lora/test_fused_moe_lora_kernel.py new file mode 100644 index 000000000000..6283c42a8595 --- /dev/null +++ b/test/manual/lora/test_fused_moe_lora_kernel.py @@ -0,0 +1,324 @@ +# adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_fused_moe_lora_kernel.py + +import os +import random + +import pytest +import torch + +from sglang.srt.distributed import ( + init_distributed_environment, + initialize_model_parallel, +) +from sglang.srt.distributed.parallel_state import ( + get_tensor_model_parallel_world_size, +) +from sglang.srt.lora.triton_ops import fused_moe_lora +from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size +from sglang.srt.utils import set_random_seed + + +def round_up(x, base): + return ((x + base - 1) // base) * base + + +def CEILDIV(x, y): + return (x + y - 1) // y + + +def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): + """ + Split `num_tokens` into `num_sequences` sequences. + Each sequence randomly selects 1 LoRA index from [0, max_loras), + and all tokens in that sequence are assigned this LoRA index. + + Args: + num_tokens (int): Total number of tokens. + num_sequences (int): Number of sequences to split the tokens into. + max_loras (int): Total number of available LoRA modules. + + Returns: + torch.Tensor: 1D tensor of shape [num_tokens], where each value + is the LoRA index assigned to that token. 
+ """ + assert num_sequences > 0 and max_loras > 0 + assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences" + + # Compute token distribution per sequence (distribute remainder evenly) + tokens_per_seq = num_tokens // num_sequences + remainder = num_tokens % num_sequences + + token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32) + + start = 0 + for seq_idx in range(num_sequences): + # Determine the token range for this sequence + end = start + tokens_per_seq + (1 if seq_idx < remainder else 0) + + # Randomly select one LoRA ID for this sequence + lora_id = random.randint(0, max_loras - 1) + + # Assign the same LoRA ID to all tokens in this sequence + token_lora_mapping[start:end] = lora_id + + start = end + + return token_lora_mapping + + +def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int): + """ + For each token, randomly select `top_k_num` distinct experts out of `num_experts`, + and assign normalized random weights that sum to 1. + + Args: + num_tokens (int): Total number of tokens. + num_experts (int): Total number of available experts. + top_k_num (int): Number of experts to select per token. + + Returns: + expert_indices (torch.Tensor): shape [num_tokens, top_k_num], + expert index for each token. + expert_weights (torch.Tensor): shape [num_tokens, top_k_num], + normalized weights (sum = 1 per row). 
+ """ + assert top_k_num <= num_experts, "top_k_num must be <= num_experts" + + # Randomly select top_k_num distinct experts for each token + expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) + for i in range(num_tokens): + # Randomly choose unique expert indices + selected = torch.randperm(num_experts)[:top_k_num] + expert_indices[i] = selected + + # Generate random weights and normalize along dim=1 + expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32) + expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True) + + return expert_indices, expert_weights + + +def sample_data( + num_tokens: int, + num_sequences: int, + max_loras: int, + num_experts: int, + top_k_num: int, +): + topk_ids, topk_weights = assign_experts_to_tokens( + num_tokens, num_experts, top_k_num + ) + token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) + return topk_ids, topk_weights, token_lora_mapping + + +def use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + fully_sharded=False, + offset=0, +): + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) + + # init output tensors + sorted_token_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + ) + expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32) + num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + lora_ids = torch.arange(max_loras + 2, dtype=torch.int32) + + # call kernel + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + 
max_num_m_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + adapter_enabled, + lora_ids, + ) + + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "NUM_WARPS": 4, + "NUM_STAGES": 3, + "SPLIT_K": 1, + } + + mul_routed_weight = False + expert_ids = expert_ids.view(max_loras, -1) + sorted_token_ids = sorted_token_ids.view(max_loras, -1) + + fused_moe_lora( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_lora_rank, + top_k_num, + lora_ids, + adapter_enabled, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], + config["SPLIT_K"], + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], + config["SPLIT_K"], + mul_routed_weight, + fully_sharded=fully_sharded, + offset=offset, + ) + + +def use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, +): + outputs = [] + for i in range(hidden_states.shape[0]): + lora_idx = token_lora_mapping[i] + expert_ids = topk_ids[i] + lora_a = lora_a_stacked[0][lora_idx][expert_ids] + lora_b = lora_b_stacked[0][lora_idx][expert_ids] + tensors = [ + hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num) + ] + outputs.append(torch.stack(tensors, dim=0)) + return torch.stack(outputs, dim=0) + + +DTYPES = [torch.float16, torch.bfloat16] +DEVICES = [f"cuda:{0}"] +SEED = [42] + + +@pytest.mark.parametrize("num_tokens", [100]) +@pytest.mark.parametrize("top_k_num", [6, 12]) +@pytest.mark.parametrize("num_experts", [64]) +@pytest.mark.parametrize("max_loras", [4, 6, 16]) +@pytest.mark.parametrize("N", [1408]) +@pytest.mark.parametrize("K", [2048]) +@pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) 
+@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_fused_moe_lora_kernel( + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, + dtype, + device, + seed, +): + torch.set_default_device(device) + set_random_seed(seed) + # the number of randomly generated sentences. + num_sequences = 10 + # generate data + topk_ids, topk_weights, token_lora_mapping = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + # init lora weights + lora_a_stacked = [ + torch.rand( + ( + max_loras, + num_experts, + max_lora_rank, + K, + ), + dtype=dtype, + ) + ] + lora_b_stacked = [ + torch.rand( + ( + max_loras, + num_experts, + N, + max_lora_rank, + ), + dtype=dtype, + ) + ] + hidden_states = torch.rand( + ( + num_tokens, + K, + ), + dtype=dtype, + ) + + # fused_moe_lora_kernel output + output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype) + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + ) + # pytorch output + output2 = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, + ) + + torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) diff --git a/test/manual/lora/test_moe_lora_align_sum.py b/test/manual/lora/test_moe_lora_align_sum.py new file mode 100644 index 000000000000..e69de29bb2d1 From 3e0047a205848a7ef990a9705ff2441f9ba9c415 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 31 Jan 2026 12:43:49 -0500 Subject: [PATCH 066/150] fix tests --- .../manual/lora/test_fused_moe_lora_kernel.py | 2 +- test/manual/lora/test_moe_lora_align_sum.py | 95 +++++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git 
a/test/manual/lora/test_fused_moe_lora_kernel.py b/test/manual/lora/test_fused_moe_lora_kernel.py index 6283c42a8595..72c4356e8431 100644 --- a/test/manual/lora/test_fused_moe_lora_kernel.py +++ b/test/manual/lora/test_fused_moe_lora_kernel.py @@ -143,7 +143,7 @@ def use_fused_moe_lora_kernel( lora_ids = torch.arange(max_loras + 2, dtype=torch.int32) # call kernel - ops.moe_lora_align_block_size( + moe_align_block_size( topk_ids, token_lora_mapping, num_experts, diff --git a/test/manual/lora/test_moe_lora_align_sum.py b/test/manual/lora/test_moe_lora_align_sum.py index e69de29bb2d1..b8720901ddea 100644 --- a/test/manual/lora/test_moe_lora_align_sum.py +++ b/test/manual/lora/test_moe_lora_align_sum.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_moe_lora_align_sum.py +import random + +import pytest +import torch + +from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size + + +def round_up(x, base): + return ((x + base - 1) // base) * base + + +def CEILDIV(x, y): + return (x + y - 1) // y + + +def sample_data(num_experts, max_loras, num_tokens, topk_num): + topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32) + token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32) + + for i in range(num_tokens): + pool = list(range(num_experts)) + random.shuffle(pool) + for j in range(topk_num): + topk_ids[i, j] = pool[j] + token_lora_mapping[i] = random.randint(0, max_loras - 1) + + return topk_ids.to("cuda"), token_lora_mapping.to("cuda") + + +@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920 +@pytest.mark.parametrize("topk_num", [6]) +@pytest.mark.parametrize("num_experts", [64, 128, 256, 512]) +@pytest.mark.parametrize("max_loras", [2, 32]) +@pytest.mark.parametrize("block_size", [16]) +def test_moe_lora_align_block_size( + num_tokens, topk_num, num_experts, max_loras, block_size +): + # sample data + random.seed(1) + topk_ids, token_lora_mapping = sample_data( + 
num_experts, max_loras, num_tokens, topk_num + ) + + # compute paddings + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) + + # init output tensors + sorted_token_ids = torch.full( + (max_loras * max_num_tokens_padded,), + topk_ids.numel(), + dtype=torch.int32, + device="cuda", + ) + expert_ids = torch.full( + (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda") + adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda") + lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda") + + # call kernel + moe_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + adapter_enabled, + lora_ids, + ) + + # verify values + expert_ids = expert_ids.view(max_loras, -1) + sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size) + + for lora_idx in range(max_loras): + for token_idx in range(sorted_token_ids.size(1)): + block = sorted_token_ids[lora_idx][token_idx] + indices = block[block != topk_ids.numel()] + if indices.numel() > 0: + expert_id = expert_ids[lora_idx][token_idx] + assert torch.all(topk_ids.view(-1)[indices] == expert_id) + + +if __name__ == "__main__": + pytest.main([__file__]) From e56b45151f2f20fc9ed414f27784f124f8aa2ec3 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 31 Jan 2026 14:10:21 -0500 Subject: [PATCH 067/150] add unit test for lora + base path --- test/manual/lora/test_lora_moe_runner.py | 268 +++++++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 test/manual/lora/test_lora_moe_runner.py diff --git a/test/manual/lora/test_lora_moe_runner.py 
b/test/manual/lora/test_lora_moe_runner.py new file mode 100644 index 000000000000..07e5f0e0639e --- /dev/null +++ b/test/manual/lora/test_lora_moe_runner.py @@ -0,0 +1,268 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import random + +import pytest +import torch + +from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import ( + TritonMoeQuantInfo, + TritonRunnerInput, +) +from sglang.srt.lora.lora_moe_runners import LoRAInfo, TritonRunnerCoreWithLoRA +from sglang.srt.utils import set_random_seed + + +def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): + assert num_sequences > 0 and max_loras > 0 + assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences" + + tokens_per_seq = num_tokens // num_sequences + remainder = num_tokens % num_sequences + + token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32) + + start = 0 + for seq_idx in range(num_sequences): + end = start + tokens_per_seq + (1 if seq_idx < remainder else 0) + lora_id = random.randint(0, max_loras - 1) + token_lora_mapping[start:end] = lora_id + start = end + + return token_lora_mapping + + +def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int): + assert top_k_num <= num_experts, "top_k_num must be <= num_experts" + + expert_indices = 
torch.empty((num_tokens, top_k_num), dtype=torch.int32) + for i in range(num_tokens): + selected = torch.randperm(num_experts)[:top_k_num] + expert_indices[i] = selected + + expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32) + expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True) + + return expert_indices, expert_weights + + +def sample_data(num_tokens: int, num_sequences: int, max_loras: int, num_experts: int, top_k_num: int): + topk_ids, topk_weights = assign_experts_to_tokens(num_tokens, num_experts, top_k_num) + token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) + return topk_ids, topk_weights, token_lora_mapping + + +def create_lora_info(token_lora_mapping, topk_ids, max_loras, num_experts, max_lora_rank, hidden_dim, intermediate_dim, gate_up_dim, dtype, device): + gate_up_lora_a_weights = torch.randn((max_loras, num_experts, max_lora_rank, hidden_dim), dtype=dtype, device=device) + gate_up_lora_b_weights = torch.randn((max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device) + down_lora_a_weights = torch.randn((max_loras, num_experts, max_lora_rank, intermediate_dim), dtype=dtype, device=device) + down_lora_b_weights = torch.randn((max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device) + + num_tokens = token_lora_mapping.shape[0] + dispatched_tokens = [] + dispatched_experts = [] + dispatched_loras = [] + + for token_idx in range(num_tokens): + lora_id = token_lora_mapping[token_idx] + for k in range(topk_ids.shape[1]): + expert_id = topk_ids[token_idx, k] + dispatched_tokens.append(token_idx) + dispatched_experts.append(expert_id) + dispatched_loras.append(lora_id) + + token_ids = torch.tensor(dispatched_tokens, dtype=torch.int32, device=device) + expert_ids = torch.tensor(dispatched_experts, dtype=torch.int32, device=device) + lora_ids = torch.tensor(dispatched_loras, dtype=torch.int32, device=device) + + lora_ranks = 
torch.full((max_loras,), max_lora_rank, dtype=torch.int32, device=device) + lora_scalings = torch.ones(max_loras, dtype=dtype, device=device) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=device) + + return LoRAInfo( + gate_up_lora_a_weights=gate_up_lora_a_weights, + gate_up_lora_b_weights=gate_up_lora_b_weights, + down_lora_a_weights=down_lora_a_weights, + down_lora_b_weights=down_lora_b_weights, + token_ids=token_ids, + expert_ids=expert_ids, + lora_ids=lora_ids, + lora_ranks=lora_ranks, + lora_scalings=lora_scalings, + adapter_enabled=adapter_enabled, + max_lora_rank=max_lora_rank, + num_experts=num_experts, + ) + + +def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info): + num_tokens, hidden_dim = hidden_states.shape + top_k = topk_ids.shape[1] + num_experts = w13.shape[0] + intermediate_dim = w2.shape[2] + + hidden_expanded = hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden_dim) + gate_up_out = torch.zeros(num_tokens * top_k, w13.shape[1], dtype=hidden_states.dtype, device=hidden_states.device) + + for expert_id in range(num_experts): + mask = (topk_ids == expert_id).flatten() + if mask.any(): + gate_up_out[mask] = hidden_expanded[mask] @ w13[expert_id].T + if b13 is not None: + gate_up_out[mask] += b13[expert_id] + + gate_up_out = gate_up_out.view(num_tokens, top_k, -1) + + if lora_info.max_lora_rank > 0: + for i in range(num_tokens): + for k in range(top_k): + expert_id = topk_ids[i, k] + lora_id = lora_info.lora_ids[i * top_k + k] + lora_a = lora_info.gate_up_lora_a_weights[lora_id, expert_id] + lora_b = lora_info.gate_up_lora_b_weights[lora_id, expert_id] + lora_delta = lora_info.lora_scalings[lora_id] * (lora_b @ (lora_a @ hidden_states[i])) + gate_up_out[i, k] += lora_delta + + gate_up_dim = gate_up_out.shape[-1] + gate_dim = gate_up_dim // 2 + gate = gate_up_out[..., :gate_dim] + up = gate_up_out[..., gate_dim:] + intermediate_out = torch.nn.functional.silu(gate) * 
up + + down_out = torch.zeros(num_tokens, top_k, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device) + + for expert_id in range(num_experts): + mask = (topk_ids == expert_id) + if mask.any(): + masked_intermediate = intermediate_out[mask] + down_out[mask] = masked_intermediate @ w2[expert_id].T + if b2 is not None: + down_out[mask] += b2[expert_id] + + if lora_info.max_lora_rank > 0: + for i in range(num_tokens): + for k in range(top_k): + expert_id = topk_ids[i, k] + lora_id = lora_info.lora_ids[i * top_k + k] + lora_a = lora_info.down_lora_a_weights[lora_id, expert_id] + lora_b = lora_info.down_lora_b_weights[lora_id, expert_id] + lora_delta = lora_info.lora_scalings[lora_id] * (lora_b @ (lora_a @ intermediate_out[i, k])) + down_out[i, k] += lora_delta + + weighted_out = down_out * topk_weights.unsqueeze(-1) + final_out = weighted_out.sum(dim=1) + + return final_out + + +DTYPES = [torch.float16, torch.bfloat16] +DEVICES = ["cuda:0"] +SEED = [42] + + +@pytest.mark.parametrize("num_tokens", [32]) +@pytest.mark.parametrize("top_k_num", [2, 4]) +@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("max_loras", [2, 4]) +@pytest.mark.parametrize("hidden_dim", [512]) +@pytest.mark.parametrize("intermediate_dim", [1024]) +@pytest.mark.parametrize("max_lora_rank", [16, 32]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_lora_moe_runner( + num_tokens, + top_k_num, + num_experts, + max_loras, + hidden_dim, + intermediate_dim, + max_lora_rank, + dtype, + device, + seed, +): + torch.set_default_device(device) + set_random_seed(seed) + + num_sequences = 4 + topk_ids, topk_weights, token_lora_mapping = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + gate_up_dim = intermediate_dim * 2 + w13 = torch.randn(num_experts, gate_up_dim, hidden_dim, dtype=dtype) + w2 = torch.randn(num_experts, hidden_dim, intermediate_dim, 
dtype=dtype) + b13 = torch.randn(num_experts, gate_up_dim, dtype=dtype) + b2 = torch.randn(num_experts, hidden_dim, dtype=dtype) + + hidden_states = torch.randn(num_tokens, hidden_dim, dtype=dtype) + + lora_info = create_lora_info( + token_lora_mapping=token_lora_mapping, + topk_ids=topk_ids, + max_loras=max_loras, + num_experts=num_experts, + max_lora_rank=max_lora_rank, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + gate_up_dim=gate_up_dim, + dtype=dtype, + device=device, + ) + + num_dispatched = num_tokens * top_k_num + sorted_token_ids = torch.arange(num_dispatched, dtype=torch.int32, device=device) + expert_ids = topk_ids.flatten().to(dtype=torch.int32, device=device) + num_tokens_post_padded = torch.tensor([num_dispatched], dtype=torch.int32, device=device) + + runner_input = TritonRunnerInput( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + ) + + quant_info = TritonMoeQuantInfo( + w13_weight=w13, + w2_weight=w2, + b13=b13, + b2=b2, + ) + + config = MoeRunnerConfig( + activation="silu", + is_gated=True, + inplace=False, + no_combine=False, + gemm1_alpha=None, + gemm1_clamp_limit=None, + routed_scaling_factor=1.0, + apply_router_weight_on_input=False, + ) + + runner = TritonRunnerCoreWithLoRA(config) + running_state = {"config": {}} + + lora_output = runner.run(runner_input, quant_info, running_state, lora_info=lora_info) + + torch_output = torch_naive_moe_with_lora( + hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info + ) + + torch.testing.assert_close(lora_output.hidden_states, torch_output, atol=1e-1, rtol=1e-1) \ No newline at end of file From eb24157f3dd1ddd1e1cb324c43687f4da936c6e7 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 31 Jan 2026 14:42:57 -0500 Subject: [PATCH 068/150] add end to end test --- test/manual/lora/test_lora_moe_end_to_end.py | 165 
+++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 test/manual/lora/test_lora_moe_end_to_end.py diff --git a/test/manual/lora/test_lora_moe_end_to_end.py b/test/manual/lora/test_lora_moe_end_to_end.py new file mode 100644 index 000000000000..3f3954571a17 --- /dev/null +++ b/test/manual/lora/test_lora_moe_end_to_end.py @@ -0,0 +1,165 @@ +""" +End-to-end test for LoRA MoE model inference using SGLang. + +This script loads a LoRA MoE model using SGLang runner, runs inference on a test dataset, +and compares outputs against gold labels to validate correctness. +""" + +import json +import os +import sys +from typing import List, Dict, Any +from urllib.request import urlopen + +import torch +from sglang.test.runners import SRTRunner + +# Configuration - set your model and LoRA paths here +MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" # Your LoRA MoE model path +LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" # REQUIRED: Your LoRA adapter path +TEST_DATA_URL = "https://huggingface.co/jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B/blob/main/training_dataset.json" # URL to test data JSON file + + +def load_test_dataset(test_data_url: str) -> List[Dict[str, Any]]: + """Load test dataset from JSON URL.""" + try: + with urlopen(test_data_url) as response: + test_dataset = json.loads(response.read().decode('utf-8')) + except Exception as e: + raise RuntimeError(f"Failed to load test data from URL {test_data_url}: {e}") + + return test_dataset + + +def run_lora_moe_inference_test(): + """Run end-to-end test for LoRA MoE model inference using SGLang.""" + + print("=== LoRA MoE End-to-End Test (SGLang) ===\n") + + print(f"Model: {MODEL_PATH}") + print(f"LoRA Path: {LORA_PATH}") + print(f"Test Data URL: {TEST_DATA_URL}") + print() + + # Load test dataset + try: + test_dataset = load_test_dataset(TEST_DATA_URL) + print(f"Loaded {len(test_dataset)} test cases from {TEST_DATA_URL}") + except Exception as e: + print(f"Error loading test data: {e}") 
+ return False + + # Initialize results tracking + results = [] + total_tests = len(test_dataset) + correct_predictions = 0 + + try: + # Initialize SGLang runner + print("Initializing SGLang runner...") + with SRTRunner( + model_path=MODEL_PATH, + torch_dtype=torch.float32, + model_type="generation", + trust_remote_code=True, + lora_paths=[LORA_PATH], + max_loras_per_batch=1, + ) as runner: + print("SGLang runner initialized successfully. Running inference tests...\n") + + # Run inference on each test case + for i, test_case in enumerate(test_dataset, 1): + instruction = test_case["instruction"] + expected_output = test_case["output"] + test_type = test_case["type"] + + print(f"Test {i}/{total_tests}: {test_type}") + print(f"Instruction: {instruction}") + print(f"Expected: '{expected_output}'") + + try: + # Run inference using SGLang runner + model_output = runner.forward( + prompts=[instruction], + max_new_tokens=50, # Adjust as needed for your model + lora_paths=[LORA_PATH], + ) + + # Extract the generated text + generated_output = model_output.output_strs[0] + print(f"Generated: '{generated_output}'") + + # Compare with expected output (exact match for simplicity) + is_correct = generated_output.strip() == expected_output.strip() + + if is_correct: + correct_predictions += 1 + print("✓ PASS") + else: + print("✗ FAIL") + + # Store result + results.append({ + "test_id": i, + "type": test_type, + "instruction": instruction, + "expected": expected_output, + "generated": generated_output, + "correct": is_correct + }) + + except Exception as e: + print(f"✗ ERROR: {e}") + results.append({ + "test_id": i, + "type": test_type, + "instruction": instruction, + "expected": expected_output, + "generated": f"ERROR: {e}", + "correct": False + }) + + print("-" * 50) + + # Print final statistics + accuracy = correct_predictions / total_tests * 100 + + print("\n=== Test Results ===") + print(f"Total tests: {total_tests}") + print(f"Correct predictions: {correct_predictions}") + 
print(".2f") + print() + + # Print detailed results + print("Detailed Results:") + for result in results: + status = "PASS" if result["correct"] else "FAIL" + print(f"Test {result['test_id']}: {result['type']} - {status}") + + # Group by type + type_stats = {} + for result in results: + test_type = result["type"] + if test_type not in type_stats: + type_stats[test_type] = {"total": 0, "correct": 0} + type_stats[test_type]["total"] += 1 + if result["correct"]: + type_stats[test_type]["correct"] += 1 + + print("\nResults by Type:") + for test_type, stats in type_stats.items(): + type_accuracy = stats["correct"] / stats["total"] * 100 + print(f" {test_type}: {stats['correct']}/{stats['total']} ({type_accuracy:.1f}%)") + + return accuracy >= 0 # Always return True for now, adjust threshold as needed + + except Exception as e: + print(f"Test failed with error: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + success = run_lora_moe_inference_test() + sys.exit(0 if success else 1) From c8bbc2521ed4397c98f347fab5d82e872cd3875b Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 31 Jan 2026 15:18:10 -0500 Subject: [PATCH 069/150] fix layer_id issue --- python/sglang/srt/lora/lora_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 649c2b1aa52c..638591215aad 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -75,7 +75,6 @@ def __init__( self.tp_rank: int = tp_rank self.lora_added_tokens_size: Optional[int] = None - # Store eviction policy from server args self.eviction_policy = server_args.lora_eviction_policy @@ -354,7 +353,6 @@ def update_lora_info(self): down_lora_a_weights=down_a, down_lora_b_weights=down_b, ) - module.max_lora_rank = self.max_lora_rank continue target_module = get_target_module_name( @@ -587,6 +585,7 @@ def init_lora_modules(self): if 
isinstance(module, FusedMoE) and all( x in self.target_modules for x in ["gate_up_proj", "down_proj"] ): + layer_id = get_layer_id(module_name) self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module ) From 64c3d96308f89ccc65f74a72b76eff40142d59c3 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 31 Jan 2026 16:37:32 -0500 Subject: [PATCH 070/150] Add moe lora align sum kernel --- .../csrc/moe/moe_lora_align_sum_kernel.cu | 764 ++++++++++++++++++ test/manual/lora/test_moe_lora_align_sum.py | 26 +- 2 files changed, 787 insertions(+), 3 deletions(-) create mode 100644 sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu diff --git a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu new file mode 100644 index 000000000000..1203217eb754 --- /dev/null +++ b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu @@ -0,0 +1,764 @@ +// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu + +#include +#include +#include + +#include + +#include + +#include "utils.h" + +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +namespace moe { +namespace batched_moe_align_block_size { + +// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. +static constexpr int32_t num_threads = 1024; +static constexpr int32_t num_blocks = 1; +__global__ void batched_moe_align_block_size_kernel( + int32_t const num_batches, int32_t const max_tokens_per_batch, + int32_t const block_size, int32_t const* __restrict__ batch_num_tokens, + int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids, + int32_t* __restrict__ num_tokens_post_pad) { + // TODO(varun): This is a naive implementation. Could be optimized. 
+ + size_t const batch_id = threadIdx.x; + size_t const stride = blockDim.x * gridDim.x; + int32_t const num_blocks_per_batch = + CEILDIV(max_tokens_per_batch, block_size); + int32_t const sorted_ids_size = + num_blocks_per_batch * num_batches * block_size; + int32_t const block_ids_size = sorted_ids_size / block_size; + int32_t const SENTINEL = + num_batches * max_tokens_per_batch; // To denote invalid entries. + // Intialize sorted_ids + for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) { + sorted_ids[i] = SENTINEL; + } + // Intialize expert_ids with -1 + for (size_t i = threadIdx.x; i < block_ids_size; i += stride) { + block_ids[i] = -1; + } + + int32_t b_num_tokens = 0; + if (batch_id < num_batches) { + b_num_tokens = batch_num_tokens[batch_id]; + } + int32_t const ceil_b_num_tokens = + CEILDIV(b_num_tokens, block_size) * block_size; + + // Compute prefix sum over token counts per expert + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + int cumsum_val; + BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val); + __syncthreads(); + + bool const is_last_batch = batch_id == (num_batches - 1); + if (is_last_batch) { + *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens; + } + + if (batch_id < num_batches) { + int32_t const batch_offset = batch_id * max_tokens_per_batch; + for (size_t i = 0; i < b_num_tokens; ++i) { + sorted_ids[cumsum_val + i] = batch_offset + i; + } + + int32_t const block_start = cumsum_val / block_size; + int32_t const num_blocks = ceil_b_num_tokens / block_size; + for (size_t i = 0; i < num_blocks; ++i) { + block_ids[block_start + i] = batch_id; + } + } +} +} // namespace batched_moe_align_block_size + +template +__device__ void _moe_align_block_size( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, + 
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id, + int32_t topk_num, int32_t* token_mask, bool has_expert_map) { + extern __shared__ int32_t shared_counts[]; + + // Compute input buffer offsets. Typically these will all be 0, except when + // using Multi LoRA. + int sorted_token_ids_offset = max_num_tokens_padded * model_offset; + int expert_ids_offset = max_num_m_blocks * model_offset; + int cumsum_offset = (num_experts + 1) * model_offset; + + // Use separate threadblocks to fill sorted_token_ids. + // This is safe since the current kernel does not use sorted_token_ids. + if (blockIdx.x % 2) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; + it += blockDim.x) { + sorted_token_ids[sorted_token_ids_offset + it] = numel; + } + return; + } + + const int warp_id = threadIdx.x / WARP_SIZE; + const int my_expert_start = warp_id * experts_per_warp; + + for (int i = 0; i < experts_per_warp; ++i) { + if (my_expert_start + i < padded_num_experts) { + shared_counts[warp_id * experts_per_warp + i] = 0; + } + } + + __syncthreads(); + + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i]; + if (expert_id >= num_experts) { + continue; + } + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid experts + if (expert_id == -1) continue; + } + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + int mask = token_mask == nullptr ? 
1 : token_mask[i / topk_num]; + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], + mask); + } + + __syncthreads(); + + // Compute prefix sum over token counts per expert + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + int expert_count = 0; + int expert_id = threadIdx.x; + if (expert_id < num_experts) { + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; + expert_count = CEILDIV(expert_count, block_size) * block_size; + } + + int cumsum_val; + BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val); + if (expert_id <= num_experts) { + cumsum[cumsum_offset + expert_id] = cumsum_val; + } + + if (expert_id == num_experts) { + total_tokens_post_pad[model_offset] = cumsum_val; + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + for (int i = cumsum[cumsum_offset + threadIdx.x]; + i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { + expert_ids[expert_ids_offset + i / block_size] = threadIdx.x; + } + } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = + cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; + for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { + expert_ids[expert_ids_offset + i] = inactive_expert_id; + } +} + +template +__device__ void _moe_align_block_size_small_batch_expert( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, + size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks, + int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num, + int32_t* token_mask, bool has_expert_map) { + // Compute input buffer offsets. 
Typically these will all be 0, except when + // using Multi LoRA. + int sorted_token_ids_offset = max_num_tokens_padded * model_offset; + int expert_ids_offset = max_num_m_blocks * model_offset; + + // Use an additional group of threads to fill sorted_token_ids. + // Since the current kernel will use sorted_token_ids afterward, + // we fill sorted_token_ids within the same threadblock to make + // synchronization easier. + if (threadIdx.x < fill_threads) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; + it += fill_threads) { + sorted_token_ids[sorted_token_ids_offset + it] = numel; + } + // Three __syncthreads() corresponding to the other threads + __syncthreads(); + __syncthreads(); + __syncthreads(); + return; + } + + const size_t tid = threadIdx.x - fill_threads; + const size_t stride = blockDim.x - fill_threads; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(tid + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid expert + if (expert_id == -1) continue; + } + int mask = token_mask == nullptr ? 
1 : token_mask[i / topk_num]; + tokens_cnts[(tid + 1) * num_experts + expert_id] += mask; + } + + __syncthreads(); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= stride; ++i) { + tokens_cnts[i * num_experts + tid] += + tokens_cnts[(i - 1) * num_experts + tid]; + } + } + + __syncthreads(); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * + block_size; + } + total_tokens_post_pad[model_offset] = + static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[expert_ids_offset + i / block_size] = tid; + } + } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + tid; + for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) { + expert_ids[expert_ids_offset + i] = inactive_expert_id; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid expert + if (expert_id == -1) continue; + } + int32_t rank_post_pad = + tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + + if (token_mask == nullptr || token_mask[i / topk_num]) { + sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } + } +} + +template +__device__ void _count_and_sort_expert_tokens( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask, + int32_t model_offset, int32_t topk_num, bool has_expert_map) { + const size_t tid = blockIdx.y * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * 
gridDim.y; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (expert_id >= num_experts) { + continue; + } + + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid experts + if (expert_id == -1) continue; + } + + if (token_mask == nullptr || token_mask[i / topk_num]) { + int32_t rank_post_pad = atomicAdd( + &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); + sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = + i; + } + } +} + +template +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, + int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, + int32_t topk_num, bool has_expert_map) { + _moe_align_block_size( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, padded_num_experts, experts_per_warp, block_size, numel, + cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size), + 0, 0, topk_num, nullptr, has_expert_map); +} + +template +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) { + _count_and_sort_expert_tokens( + topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, + max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map); +} + +template +__global__ void moe_sum_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., topk, d] + const int d) { + const int64_t 
token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + scalar_t x = 0.0; +#pragma unroll + for (int k = 0; k < TOPK; ++k) { + x += __ldg(&input[token_idx * TOPK * d + k * d + idx]); + } + out[token_idx * d + idx] = x; + } +} + +template +__global__ void moe_align_block_size_small_batch_expert_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, + size_t numel, int32_t max_num_tokens_padded, int32_t topk_num, + bool has_expert_map) { + _moe_align_block_size_small_batch_expert( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, block_size, numel, max_num_tokens_padded, + CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr, + has_expert_map); +} + +template +__global__ void moe_lora_align_block_size_kernel( + scalar_t* __restrict__ topk_ids, int32_t* __restrict__ token_lora_mapping, + int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, + int max_loras, size_t numel, int max_num_tokens_padded, + int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t topk_num, + int32_t* total_tokens_post_pad, int32_t* adapter_enabled, + int32_t* __restrict__ cumsum, int32_t experts_per_warp, + int32_t padded_num_experts, int32_t* lora_ids, + int32_t* __restrict__ token_mask, bool has_expert_map) { + int lora_idx = blockIdx.x / 2; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; + } + + // Populate the token_mask based on the token-LoRA mapping + int num_tokens = numel / topk_num; + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; + + for (int i = 0; i < num_tokens; i++) { + token_mask[(lora_id * num_tokens) + i] = + (int)token_lora_mapping[i] == lora_id; + } + } + + 
__syncthreads(); + + _moe_align_block_size( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, padded_num_experts, experts_per_warp, block_size, numel, + cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num, + &token_mask[(lora_id * num_tokens)], has_expert_map); +} + +template +__global__ void lora_count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask, + int32_t* lora_ids, bool has_expert_map) { + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1) { + return; + } + + int num_tokens = numel / topk_num; + + _count_and_sort_expert_tokens( + topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, + max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id, + topk_num, has_expert_map); +} + +template +__global__ void moe_lora_align_block_size_small_batch_expert_kernel( + scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping, + int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, + int max_loras, size_t numel, int max_num_tokens_padded, + int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int topk_num, + int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids, + int32_t* token_mask, bool has_expert_map) { + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; + } + + int num_tokens = numel / topk_num; + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; + + for (int i = 0; i < num_tokens; i++) { + token_mask[(lora_id * num_tokens) + i] = + (int)token_lora_mapping[i] == lora_id; + } + } + + __syncthreads(); + + 
_moe_align_block_size_small_batch_expert( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks, + -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)], + has_expert_map); +} + +} // namespace moe + +// taken from +// https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + std::optional maybe_expert_map) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int experts_per_warp = WARP_SIZE; + int threads = 1024; + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + + // BlockScan uses 1024 threads and assigns one thread per expert. + TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + bool has_expert_map = maybe_expert_map.has_value(); + torch::Tensor expert_map; + if (has_expert_map) { + expert_map = maybe_expert_map.value(); + } else { + expert_map = torch::empty({0}, options_int); + } + + DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `cumsum` tensors + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t threads = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_size = + ((threads + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + auto 
small_batch_expert_kernel = + moe::moe_align_block_size_small_batch_expert_kernel< + scalar_t, fill_threads>; + small_batch_expert_kernel<<<1, fill_threads + threads, + shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), num_experts, block_size, + topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1), + has_expert_map); + } else { + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); + auto align_kernel = moe::moe_align_block_size_kernel; + + size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); + size_t shared_mem_size = + num_warps * experts_per_warp * sizeof(int32_t); + + // launch two threadblocks + // blockIdx.x == 0: counting experts and aligning + // blockIdx.x == 1: filling sorted_token_ids + align_kernel<<<2, threads, shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), num_experts, padded_num_experts, + experts_per_warp, block_size, topk_ids.numel(), + cumsum_buffer.data_ptr(), sorted_token_ids.size(0), + topk_ids.size(1), has_expert_map); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + dim3 gridDims(1, actual_blocks); + + auto sort_kernel = + moe::count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), expert_map.data_ptr(), + topk_ids.numel(), num_experts, sorted_token_ids.size(0), + topk_ids.size(1), has_expert_map); + } + }); +} + +void batched_moe_align_block_size(int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& batch_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor batch_ids, + 
torch::Tensor num_tokens_post_pad) { + namespace batched_kernel = moe::batched_moe_align_block_size; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int32_t const B = batch_num_tokens.size(0); + int32_t const num_blocks_per_batch = + round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; + int32_t const num_blocks = num_blocks_per_batch * B; + int64_t const sorted_ids_size = num_blocks * block_size; + + TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size); + TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size); + TORCH_CHECK(num_tokens_post_pad.size(0) == 1); + TORCH_CHECK(B <= batched_kernel::num_threads); + + batched_kernel::batched_moe_align_block_size_kernel<<< + batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>( + B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr(), + sorted_ids.data_ptr(), batch_ids.data_ptr(), + num_tokens_post_pad.data_ptr()); +} + +void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] + torch::Tensor& output) // [num_tokens, hidden_size] +{ + const int hidden_size = input.size(-1); + const auto num_tokens = output.numel() / hidden_size; + const int topk = input.size(1); + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (topk) { + case 2: + DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + case 3: + DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + case 4: + DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + default: + at::sum_out(output, 
input, 1); + break; + } +} + +void moe_lora_align_block_size( + torch::Tensor topk_ids, torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, int64_t max_loras, + int64_t max_num_tokens_padded, int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, + torch::Tensor lora_ids, std::optional maybe_expert_map) { + const int topk_num = topk_ids.size(1); + + TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); + + int device_max_shared_mem; + auto dev = topk_ids.get_device(); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + + // BlockScan uses 1024 threads and assigns one thread per expert. + TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); + + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor token_mask = + torch::empty({max_loras * topk_ids.size(0)}, options_int); + bool has_expert_map = maybe_expert_map.has_value(); + torch::Tensor expert_map; + if (has_expert_map) { + expert_map = maybe_expert_map.value(); + } else { + expert_map = torch::empty({0}, options_int); + } + + DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t num_thread = max((int32_t)num_experts, 128); + const int32_t shared_mem = + (num_thread + 1) * num_experts * sizeof(int32_t) + + (num_experts + 1) * sizeof(int32_t); + if (shared_mem > device_max_shared_mem) { + TORCH_CHECK(false, "Shared memory usage exceeds device limit."); + } + + // threadIdx.x >= fill_threads: counting experts and aligning + // 
threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + + dim3 blockDim(num_thread + fill_threads); + auto kernel = + moe::moe_lora_align_block_size_small_batch_expert_kernel< + scalar_t, fill_threads>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); + kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), block_size, + expert_map.data_ptr(), num_experts, max_loras, + topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), lora_ids.data_ptr(), + token_mask.data_ptr(), has_expert_map); + } else { + int num_thread = 1024; + dim3 blockDim(num_thread); + size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); + + size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); + + // cumsum buffer + torch::Tensor cumsum = + torch::zeros({max_loras * (num_experts + 1)}, options_int); + + auto align_kernel = + moe::moe_lora_align_block_size_kernel; + + // launch two threadblocks for each lora + // blockIdx.x % 2 == 0: counting experts and aligning + // blockIdx.x % 2 == 1: filling sorted_token_ids + align_kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), block_size, + expert_map.data_ptr(), num_experts, max_loras, + topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), cumsum.data_ptr(), + WARP_SIZE, padded_num_experts, lora_ids.data_ptr(), + token_mask.data_ptr(), has_expert_map); + + const int block_threads = std::min(256, (int)num_thread); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + dim3 gridDims(max_loras, actual_blocks); + auto sort_kernel = + 
moe::lora_count_and_sort_expert_tokens_kernel<scalar_t>;
+
+          sort_kernel<<<gridDims, block_threads, 0, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
+              expert_map.data_ptr<int32_t>(), topk_ids.numel(), num_experts,
+              max_num_tokens_padded, topk_num, token_mask.data_ptr<int32_t>(),
+              lora_ids.data_ptr<int32_t>(), has_expert_map);
+        }
+      });
+}
+
+// TODO: Jonahbernard: remove this later
+#include <torch/extension.h>
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("moe_lora_align_block_size", &moe_lora_align_block_size,
+        "MoE LoRA Align Block Size");
+}
diff --git a/test/manual/lora/test_moe_lora_align_sum.py b/test/manual/lora/test_moe_lora_align_sum.py
index b8720901ddea..d378b33a3ca8 100644
--- a/test/manual/lora/test_moe_lora_align_sum.py
+++ b/test/manual/lora/test_moe_lora_align_sum.py
@@ -1,10 +1,29 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_moe_lora_align_sum.py
 import random
-
+import os
 import pytest
 import torch
+from torch.utils.cpp_extension import load
+
+# ==============================================================================
+# 1. JIT Compile the Kernel
+# ==============================================================================
+# Pointing specifically to the path you provided
+source_path = ".."
+ +print(f"Loading kernel from: {source_path}") + +# Check if file exists to avoid confusing compilation errors +if not os.path.exists(source_path): + raise FileNotFoundError(f"Could not find CUDA file at {source_path}") -from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size +moe_ops = load( + name="moe_lora_ops_jit", + sources=[source_path], + extra_cuda_cflags=["-O3"], + verbose=True, +) +print("Kernel loaded successfully.") def round_up(x, base): @@ -63,7 +82,7 @@ def test_moe_lora_align_block_size( lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda") # call kernel - moe_align_block_size( + moe_ops.moe_lora_align_block_size( topk_ids, token_lora_mapping, num_experts, @@ -76,6 +95,7 @@ def test_moe_lora_align_block_size( num_tokens_post_pad, adapter_enabled, lora_ids, + None, ) # verify values From 0e8c05dd004ce347211cceadd01da9f375c43013 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 31 Jan 2026 18:35:12 -0500 Subject: [PATCH 071/150] add call to moe lora align kernel --- python/sglang/srt/lora/lora_moe_runners.py | 60 ++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 5cd06993ec7c..b1dddc35ee32 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -525,15 +525,69 @@ def _add_lora_down_delta( lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] + # Define shrink_config for LoRA alignment + shrink_config = {"BLOCK_SIZE_M": 64} # Default block size, can be made configurable + + # Call moe_lora_align_block_size before the LoRA down_proj delta path + # Prepare inputs for the kernel + block_size_m = shrink_config["BLOCK_SIZE_M"] + max_loras = len(lora_info.lora_ranks) + num_tokens = M + topk_num = topk_ids.shape[1] + + # Calculate max_num_tokens_padded + max_num_tokens_padded = topk_ids.numel() + 
lora_info.num_experts * (block_size_m - 1) + max_num_tokens_padded = ((max_num_tokens_padded + block_size_m - 1) // block_size_m) * block_size_m + max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m + + # Initialize output tensors (using torch.empty like the reference implementation) + device = topk_ids.device + sorted_token_ids_lora = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=device, + ) + expert_ids_lora = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=device, + ) + num_tokens_post_padded_lora = torch.empty((max_loras,), dtype=torch.int32, device=device) + + # Create token-to-LoRA mapping (assuming all tokens use LoRA 0 for now) + token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32, device=device) + lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) + + # Call the kernel directly + torch.ops.sgl_kernel.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + lora_info.num_experts, + block_size_m, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + lora_info.adapter_enabled, + lora_ids, + None, # expert_map, set to None for now + ) + + max_loras = len(lora_info.lora_ranks) + expert_ids_lora = expert_ids_lora.view(max_loras, -1) + sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) + fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, lora_a_stacked=lora_a_stacked, lora_b_stacked=lora_b_stacked, topk_weights=topk_weights, # Use the routing weights passed to this function - sorted_token_ids=lora_info.token_ids.unsqueeze(0), - expert_ids=lora_info.expert_ids.unsqueeze(0), - num_tokens_post_padded=num_tokens_post_padded_formatted, + sorted_token_ids=sorted_token_ids_lora, + expert_ids=expert_ids_lora, + num_tokens_post_padded=num_tokens_post_padded_lora, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, 
lora_ids=lora_info.lora_ids, From a03797e597ba938e64789fb00fe193899a8d9ccc Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 31 Jan 2026 18:36:56 -0500 Subject: [PATCH 072/150] fix --- python/sglang/srt/lora/lora_moe_runners.py | 120 ++++++++++----------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index b1dddc35ee32..8306215ec0e1 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -444,15 +444,69 @@ def _add_lora_gate_up_delta( lora_a_stacked = [lora_info.gate_up_lora_a_weights] lora_b_stacked = [lora_info.gate_up_lora_b_weights] + # Define shrink_config for LoRA alignment + shrink_config = {"BLOCK_SIZE_M": 64} # Default block size, can be made configurable + + # Call moe_lora_align_block_size before the LoRA gate_up_proj delta path + # Prepare inputs for the kernel + block_size_m = shrink_config["BLOCK_SIZE_M"] + max_loras = len(lora_info.lora_ranks) + num_tokens = M + topk_num = topk_ids.shape[1] + + # Calculate max_num_tokens_padded + max_num_tokens_padded = topk_ids.numel() + lora_info.num_experts * (block_size_m - 1) + max_num_tokens_padded = ((max_num_tokens_padded + block_size_m - 1) // block_size_m) * block_size_m + max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m + + # Initialize output tensors (using torch.empty like the reference implementation) + device = topk_ids.device + sorted_token_ids_lora = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=device, + ) + expert_ids_lora = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=device, + ) + num_tokens_post_padded_lora = torch.empty((max_loras,), dtype=torch.int32, device=device) + + # Create token-to-LoRA mapping (assuming all tokens use LoRA 0 for now) + token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32, device=device) + lora_ids = 
torch.arange(max_loras, dtype=torch.int32, device=device) + + # Call the kernel directly + torch.ops.sgl_kernel.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + lora_info.num_experts, + block_size_m, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + lora_info.adapter_enabled, + lora_ids, + None, # expert_map, set to None for now + ) + + max_loras = len(lora_info.lora_ranks) + expert_ids_lora = expert_ids_lora.view(max_loras, -1) + sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) + fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=hidden_states, lora_a_stacked=lora_a_stacked, lora_b_stacked=lora_b_stacked, topk_weights=topk_weights, # Use actual routing weights - sorted_token_ids=lora_info.token_ids.unsqueeze(0), - expert_ids=lora_info.expert_ids.unsqueeze(0), - num_tokens_post_padded=num_tokens_post_padded_formatted, + sorted_token_ids=sorted_token_ids_lora, + expert_ids=expert_ids_lora, + num_tokens_post_padded=num_tokens_post_padded_lora, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, lora_ids=lora_info.lora_ids, @@ -525,69 +579,15 @@ def _add_lora_down_delta( lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] - # Define shrink_config for LoRA alignment - shrink_config = {"BLOCK_SIZE_M": 64} # Default block size, can be made configurable - - # Call moe_lora_align_block_size before the LoRA down_proj delta path - # Prepare inputs for the kernel - block_size_m = shrink_config["BLOCK_SIZE_M"] - max_loras = len(lora_info.lora_ranks) - num_tokens = M - topk_num = topk_ids.shape[1] - - # Calculate max_num_tokens_padded - max_num_tokens_padded = topk_ids.numel() + lora_info.num_experts * (block_size_m - 1) - max_num_tokens_padded = ((max_num_tokens_padded + block_size_m - 1) // block_size_m) * block_size_m - max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m - - # 
Initialize output tensors (using torch.empty like the reference implementation) - device = topk_ids.device - sorted_token_ids_lora = torch.empty( - (max_loras * max_num_tokens_padded,), - dtype=torch.int32, - device=device, - ) - expert_ids_lora = torch.empty( - (max_loras * max_num_m_blocks,), - dtype=torch.int32, - device=device, - ) - num_tokens_post_padded_lora = torch.empty((max_loras,), dtype=torch.int32, device=device) - - # Create token-to-LoRA mapping (assuming all tokens use LoRA 0 for now) - token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32, device=device) - lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) - - # Call the kernel directly - torch.ops.sgl_kernel.moe_lora_align_block_size( - topk_ids, - token_lora_mapping, - lora_info.num_experts, - block_size_m, - max_loras, - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - lora_info.adapter_enabled, - lora_ids, - None, # expert_map, set to None for now - ) - - max_loras = len(lora_info.lora_ranks) - expert_ids_lora = expert_ids_lora.view(max_loras, -1) - sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) - fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, lora_a_stacked=lora_a_stacked, lora_b_stacked=lora_b_stacked, topk_weights=topk_weights, # Use the routing weights passed to this function - sorted_token_ids=sorted_token_ids_lora, - expert_ids=expert_ids_lora, - num_tokens_post_padded=num_tokens_post_padded_lora, + sorted_token_ids=lora_info.token_ids.unsqueeze(0), + expert_ids=lora_info.expert_ids.unsqueeze(0), + num_tokens_post_padded=num_tokens_post_padded_formatted, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, lora_ids=lora_info.lora_ids, From bab26e6eafbf5313ca6721e361e23889ef1a0fec Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 1 Feb 2026 10:12:00 -0500 Subject: [PATCH 073/150] refactor to use MoE runners infra --- 
.../srt/layers/moe/moe_runner/runner.py | 44 ++++++++++++------- python/sglang/srt/lora/layers.py | 28 +++++------- python/sglang/srt/lora/lora_moe_runners.py | 20 ++++++--- 3 files changed, 53 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py index 8b58cd3115bd..9928584cd05e 100644 --- a/python/sglang/srt/layers/moe/moe_runner/runner.py +++ b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -25,14 +25,19 @@ class MoeRunner: - def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): + def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig, lora_enabled: bool = False): self.runner_backend = runner_backend self.config = config + self.lora_enabled = lora_enabled self.fused_func = None if runner_backend.is_triton(): - self.runner_core = TritonRunnerCore(config) + if lora_enabled: + from sglang.srt.lora.lora_moe_runners import TritonRunnerCoreWithLoRA + self.runner_core = TritonRunnerCoreWithLoRA(config) + else: + self.runner_core = TritonRunnerCore(config) elif runner_backend.is_triton_kernels(): self.runner_core = TritonKernelsRunnerCore(config) elif runner_backend.is_deep_gemm(): @@ -44,20 +49,22 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): else: raise NotImplementedError(f"Unsupported runner backend: {runner_backend}") - a2a_backend_name = get_moe_a2a_backend().value - runner_backend_name = runner_backend.value + # Skip fused func if LoRA is enabled (LoRA requires non-fused path) + if not lora_enabled: + a2a_backend_name = get_moe_a2a_backend().value + runner_backend_name = runner_backend.value - # TODO(cwan): add a server argument to disable fused func - self.fused_func = FusedOpPool.get_fused_func( - a2a_backend_name, runner_backend_name - ) - - if self.runner_core is None and self.fused_func is None: - raise NotImplementedError( - f"Runner backend {runner_backend} requires a fused func 
for a2a backend " - f"{a2a_backend_name}, but none is registered." + # TODO(cwan): add a server argument to disable fused func + self.fused_func = FusedOpPool.get_fused_func( + a2a_backend_name, runner_backend_name ) + if self.runner_core is None and self.fused_func is None: + raise NotImplementedError( + f"Runner backend {runner_backend} requires a fused func for a2a backend " + f"{a2a_backend_name}, but none is registered." + ) + self.down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None self.meta_overlap_args: Optional[dict] = None @@ -71,10 +78,10 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): self.fused_func = None def run( - self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo + self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo, lora_info=None ) -> CombineInput: - if self.fused_func is not None: + if self.fused_func is not None and not self.lora_enabled: return self.fused_func(dispatch_output, quant_info, self.config) assert self.runner_core is not None @@ -93,7 +100,12 @@ def run( runner_input = self.pre_permute_func( dispatch_output, quant_info, self.config, running_state ) - runner_output = self.runner_core.run(runner_input, quant_info, running_state) + + # Pass lora_info to runner_core if LoRA is enabled + if self.lora_enabled: + runner_output = self.runner_core.run(runner_input, quant_info, running_state, lora_info) + else: + runner_output = self.runner_core.run(runner_input, quant_info, running_state) runner_format = self.runner_core.runner_backend.value combine_format = dispatch_output.format.value diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 4b5b73e773f7..db73850f5e07 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -718,9 +718,14 @@ def _forward_with_lora( hidden_states=hidden_states, topk_output=topk_output ) - # Create LoRA-aware runner if not already created + # Create LoRA-enabled MoeRunner if not 
already created if self._lora_runner is None: - self._lora_runner = TritonRunnerCoreWithLoRA(base_layer.moe_runner_config) + from sglang.srt.layers.moe.moe_runner.runner import MoeRunner + self._lora_runner = MoeRunner( + base_layer.moe_runner.runner_backend, + base_layer.moe_runner_config, + lora_enabled=True + ) # Build quant info (for unquantized, this is straightforward) quant_info = TritonMoeQuantInfo( @@ -730,25 +735,12 @@ def _forward_with_lora( b2=getattr(base_layer, "w2_weight_bias", None), ) - # Get running state (includes config from pre-permute) - from sglang.srt.layers.moe.moe_runner.triton import ( - pre_permute_standard_to_triton, - ) - from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput - - running_state = {} - runner_input = pre_permute_standard_to_triton( - dispatch_output, quant_info, base_layer.moe_runner_config, running_state - ) - - # Run with LoRA integration - runner_output = self._lora_runner.run( - runner_input, quant_info, running_state, lora_info + # Run with LoRA integration using the MoeRunner infrastructure + combine_input = self._lora_runner.run( + dispatch_output, quant_info, lora_info=lora_info ) # Combine and return - combine_input = StandardCombineInput(hidden_states=runner_output.hidden_states) - final_hidden_states = base_layer.dispatcher.combine(combine_input=combine_input) final_hidden_states = final_hidden_states[ ..., :origin_hidden_states_dim diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 8306215ec0e1..a8b96c08c7b9 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -445,7 +445,9 @@ def _add_lora_gate_up_delta( lora_b_stacked = [lora_info.gate_up_lora_b_weights] # Define shrink_config for LoRA alignment - shrink_config = {"BLOCK_SIZE_M": 64} # Default block size, can be made configurable + shrink_config = { + "BLOCK_SIZE_M": 64 + } # Default block size, can be made configurable # 
Call moe_lora_align_block_size before the LoRA gate_up_proj delta path # Prepare inputs for the kernel @@ -455,8 +457,12 @@ def _add_lora_gate_up_delta( topk_num = topk_ids.shape[1] # Calculate max_num_tokens_padded - max_num_tokens_padded = topk_ids.numel() + lora_info.num_experts * (block_size_m - 1) - max_num_tokens_padded = ((max_num_tokens_padded + block_size_m - 1) // block_size_m) * block_size_m + max_num_tokens_padded = topk_ids.numel() + lora_info.num_experts * ( + block_size_m - 1 + ) + max_num_tokens_padded = ( + (max_num_tokens_padded + block_size_m - 1) // block_size_m + ) * block_size_m max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m # Initialize output tensors (using torch.empty like the reference implementation) @@ -471,10 +477,14 @@ def _add_lora_gate_up_delta( dtype=torch.int32, device=device, ) - num_tokens_post_padded_lora = torch.empty((max_loras,), dtype=torch.int32, device=device) + num_tokens_post_padded_lora = torch.empty( + (max_loras,), dtype=torch.int32, device=device + ) # Create token-to-LoRA mapping (assuming all tokens use LoRA 0 for now) - token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32, device=device) + token_lora_mapping = torch.zeros( + (num_tokens,), dtype=torch.int32, device=device + ) lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) # Call the kernel directly From b76d05a0e1416ca54b01d797c7614cc8fa9603e7 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 1 Feb 2026 10:22:22 -0500 Subject: [PATCH 074/150] update runner test case to work with refactoring --- test/manual/lora/test_lora_moe_runner.py | 30 ++++++++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/test/manual/lora/test_lora_moe_runner.py b/test/manual/lora/test_lora_moe_runner.py index 07e5f0e0639e..d06f053a6d35 100644 --- a/test/manual/lora/test_lora_moe_runner.py +++ b/test/manual/lora/test_lora_moe_runner.py @@ -18,11 +18,15 @@ import torch from 
sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.runner import MoeRunner from sglang.srt.layers.moe.moe_runner.triton import ( TritonMoeQuantInfo, TritonRunnerInput, ) -from sglang.srt.lora.lora_moe_runners import LoRAInfo, TritonRunnerCoreWithLoRA +from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput +from sglang.srt.layers.moe.topk import StandardTopKOutput +from sglang.srt.layers.moe.utils import MoeRunnerBackend +from sglang.srt.lora.lora_moe_runners import LoRAInfo from sglang.srt.utils import set_random_seed @@ -254,15 +258,31 @@ def test_lora_moe_runner( gemm1_clamp_limit=None, routed_scaling_factor=1.0, apply_router_weight_on_input=False, + num_local_experts=num_experts, ) - runner = TritonRunnerCoreWithLoRA(config) - running_state = {"config": {}} + # Create StandardTopKOutput for DispatchOutput + router_logits = torch.randn(num_tokens, num_experts, dtype=dtype, device=device) # Dummy logits + topk_output = StandardTopKOutput( + topk_weights=topk_weights, + topk_ids=topk_ids, + router_logits=router_logits, + ) + + # Create StandardDispatchOutput + dispatch_output = StandardDispatchOutput( + hidden_states=hidden_states, + hidden_states_scale=None, + topk_output=topk_output, + ) - lora_output = runner.run(runner_input, quant_info, running_state, lora_info=lora_info) + # Test the full MoeRunner flow with LoRA enabled + runner = MoeRunner(MoeRunnerBackend.TRITON, config, lora_enabled=True) + combine_input = runner.run(dispatch_output, quant_info, lora_info) + lora_output = combine_input torch_output = torch_naive_moe_with_lora( hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info ) - torch.testing.assert_close(lora_output.hidden_states, torch_output, atol=1e-1, rtol=1e-1) \ No newline at end of file + torch.testing.assert_close(lora_output.hidden_states, torch_output, atol=1e-1, rtol=1e-1) From 06d22be890de6e1bb47f64594f54ca046f68d56d Mon Sep 17 00:00:00 2001 From: 
Jonah Bernard Date: Sun, 1 Feb 2026 10:26:49 -0500 Subject: [PATCH 075/150] fix runner test case --- test/manual/lora/test_lora_moe_runner.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/manual/lora/test_lora_moe_runner.py b/test/manual/lora/test_lora_moe_runner.py index d06f053a6d35..98b5f1fd634d 100644 --- a/test/manual/lora/test_lora_moe_runner.py +++ b/test/manual/lora/test_lora_moe_runner.py @@ -13,6 +13,7 @@ # ============================================================================== import random +from unittest.mock import patch import pytest import torch @@ -277,9 +278,14 @@ def test_lora_moe_runner( ) # Test the full MoeRunner flow with LoRA enabled - runner = MoeRunner(MoeRunnerBackend.TRITON, config, lora_enabled=True) - combine_input = runner.run(dispatch_output, quant_info, lora_info) - lora_output = combine_input + # Mock global server args to avoid dependency on server initialization + class MockServerArgs: + enable_deterministic_inference = False + + with patch('sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.get_global_server_args', return_value=MockServerArgs()): + runner = MoeRunner(MoeRunnerBackend.TRITON, config, lora_enabled=True) + combine_input = runner.run(dispatch_output, quant_info, lora_info) + lora_output = combine_input torch_output = torch_naive_moe_with_lora( hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info From 8dd1ef335b8ef9da0aedf7ab28f37350ab936ea2 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 1 Feb 2026 14:11:51 -0500 Subject: [PATCH 076/150] fix --- python/sglang/srt/lora/lora_moe_runners.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index a8b96c08c7b9..6659c7a03a7f 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -508,6 +508,10 @@ def _add_lora_gate_up_delta( 
expert_ids_lora = expert_ids_lora.view(max_loras, -1) sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) + # TODO: Jonahbernard: we need some way to pass these on to the down_proj delta path + lora_info.expert_ids = expert_ids_lora + lora_info.token_ids = sorted_token_ids_lora # it could be an issue overwriting these two fields here. + fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=hidden_states, @@ -570,7 +574,6 @@ def _add_lora_down_delta( # Create num_tokens_post_padded tensor for vLLM kernel # It expects shape (max_loras,) with the same value for each LoRA num_loras = len(lora_info.lora_ranks) - num_tokens_post_padded_formatted = num_tokens_post_padded.expand(num_loras) actual_max_lora_rank = lora_info.max_lora_rank # Handle multi-LoRA: stack weights for all loaded LoRAs @@ -595,9 +598,9 @@ def _add_lora_down_delta( lora_a_stacked=lora_a_stacked, lora_b_stacked=lora_b_stacked, topk_weights=topk_weights, # Use the routing weights passed to this function - sorted_token_ids=lora_info.token_ids.unsqueeze(0), - expert_ids=lora_info.expert_ids.unsqueeze(0), - num_tokens_post_padded=num_tokens_post_padded_formatted, + sorted_token_ids=lora_info.token_ids, # this is the token_ids from the previous stage + expert_ids=lora_info.expert_ids, # this is the expert_ids from the previous stage + num_tokens_post_padded=num_tokens_post_padded, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, lora_ids=lora_info.lora_ids, From 796db387694425ff72d1736d50cb27512c2a7d98 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 1 Feb 2026 16:32:12 -0500 Subject: [PATCH 077/150] fix --- python/sglang/srt/lora/layers.py | 2 +- python/sglang/srt/lora/lora_moe_runners.py | 7 ++++++- test/manual/lora/test_lora_moe_runner.py | 10 +++++----- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index db73850f5e07..14ce99895d4c 100644 --- a/python/sglang/srt/lora/layers.py +++ 
b/python/sglang/srt/lora/layers.py @@ -722,7 +722,7 @@ def _forward_with_lora( if self._lora_runner is None: from sglang.srt.layers.moe.moe_runner.runner import MoeRunner self._lora_runner = MoeRunner( - base_layer.moe_runner.runner_backend, + base_layer.quant_method.runner.runner_backend, base_layer.moe_runner_config, lora_enabled=True ) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 6659c7a03a7f..b8115beffcae 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -592,6 +592,11 @@ def _add_lora_down_delta( lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] + # TODO: Jonahbernard: is this variable needed? + num_tokens_post_padded_lora = torch.empty( + (max_loras,), dtype=torch.int32, device=topk_weights.device + ) + fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, @@ -600,7 +605,7 @@ def _add_lora_down_delta( topk_weights=topk_weights, # Use the routing weights passed to this function sorted_token_ids=lora_info.token_ids, # this is the token_ids from the previous stage expert_ids=lora_info.expert_ids, # this is the expert_ids from the previous stage - num_tokens_post_padded=num_tokens_post_padded, + num_tokens_post_padded=num_tokens_post_padded_lora, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, lora_ids=lora_info.lora_ids, diff --git a/test/manual/lora/test_lora_moe_runner.py b/test/manual/lora/test_lora_moe_runner.py index 98b5f1fd634d..b240d2308fa6 100644 --- a/test/manual/lora/test_lora_moe_runner.py +++ b/test/manual/lora/test_lora_moe_runner.py @@ -50,7 +50,7 @@ def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): return token_lora_mapping -def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int): +def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int, dtype=torch.float32): assert 
top_k_num <= num_experts, "top_k_num must be <= num_experts" expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) @@ -58,14 +58,14 @@ def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int): selected = torch.randperm(num_experts)[:top_k_num] expert_indices[i] = selected - expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32) + expert_weights = torch.rand((num_tokens, top_k_num), dtype=dtype) expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True) return expert_indices, expert_weights -def sample_data(num_tokens: int, num_sequences: int, max_loras: int, num_experts: int, top_k_num: int): - topk_ids, topk_weights = assign_experts_to_tokens(num_tokens, num_experts, top_k_num) +def sample_data(num_tokens: int, num_sequences: int, max_loras: int, num_experts: int, top_k_num: int, dtype=torch.float32): + topk_ids, topk_weights = assign_experts_to_tokens(num_tokens, num_experts, top_k_num, dtype) token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) return topk_ids, topk_weights, token_lora_mapping @@ -205,7 +205,7 @@ def test_lora_moe_runner( num_sequences = 4 topk_ids, topk_weights, token_lora_mapping = sample_data( - num_tokens, num_sequences, max_loras, num_experts, top_k_num + num_tokens, num_sequences, max_loras, num_experts, top_k_num, dtype ) gate_up_dim = intermediate_dim * 2 From d0a9f9b6c9bce450f35d935036e1521373a95d02 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 1 Feb 2026 17:59:11 -0500 Subject: [PATCH 078/150] fix --- .../sglang/srt/layers/moe/moe_runner/runner.py | 18 +++++++++++++----- python/sglang/srt/lora/layers.py | 8 ++++++-- python/sglang/srt/lora/lora_moe_runners.py | 9 +++++---- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py index 9928584cd05e..e78912acd768 100644 --- 
a/python/sglang/srt/layers/moe/moe_runner/runner.py +++ b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -24,8 +24,12 @@ class MoeRunner: - - def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig, lora_enabled: bool = False): + def __init__( + self, + runner_backend: MoeRunnerBackend, + config: MoeRunnerConfig, + lora_enabled: bool = False, + ): self.runner_backend = runner_backend self.config = config self.lora_enabled = lora_enabled @@ -35,6 +39,7 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig, lo if runner_backend.is_triton(): if lora_enabled: from sglang.srt.lora.lora_moe_runners import TritonRunnerCoreWithLoRA + self.runner_core = TritonRunnerCoreWithLoRA(config) else: self.runner_core = TritonRunnerCore(config) @@ -80,7 +85,6 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig, lo def run( self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo, lora_info=None ) -> CombineInput: - if self.fused_func is not None and not self.lora_enabled: return self.fused_func(dispatch_output, quant_info, self.config) @@ -103,9 +107,13 @@ def run( # Pass lora_info to runner_core if LoRA is enabled if self.lora_enabled: - runner_output = self.runner_core.run(runner_input, quant_info, running_state, lora_info) + runner_output = self.runner_core.run( + runner_input, quant_info, running_state, lora_info + ) else: - runner_output = self.runner_core.run(runner_input, quant_info, running_state) + runner_output = self.runner_core.run( + runner_input, quant_info, running_state + ) runner_format = self.runner_core.runner_backend.value combine_format = dispatch_output.format.value diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 14ce99895d4c..d5b70148411e 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -647,7 +647,9 @@ def _get_lora_info( # Create adapter_enabled tensor for the current batch # All LoRAs in the 
batch are enabled by definition - adapter_enabled = torch.ones(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) + adapter_enabled = torch.ones( + len(lora_ranks), dtype=torch.int32, device=lora_ranks.device + ) # Use precomputed per-token LoRA indices from forward batch lora_indices = self.lora_backend.forward_batch.token_lora_indices @@ -667,6 +669,7 @@ def _get_lora_info( token_ids=token_ids, expert_ids=expert_ids, lora_ids=lora_ids, + token_lora_indices=lora_indices, lora_ranks=lora_ranks, lora_scalings=scalings, adapter_enabled=adapter_enabled, @@ -721,10 +724,11 @@ def _forward_with_lora( # Create LoRA-enabled MoeRunner if not already created if self._lora_runner is None: from sglang.srt.layers.moe.moe_runner.runner import MoeRunner + self._lora_runner = MoeRunner( base_layer.quant_method.runner.runner_backend, base_layer.moe_runner_config, - lora_enabled=True + lora_enabled=True, ) # Build quant info (for unquantized, this is straightforward) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index b8115beffcae..020b61c04fef 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -66,6 +66,9 @@ class LoRAInfo: expert_ids: torch.Tensor # [num_dispatched] - expert IDs lora_ids: torch.Tensor # [num_dispatched] - LoRA adapter IDs + # Original per-token LoRA mapping (unsorted) + token_lora_indices: torch.Tensor # [num_tokens] - LoRA adapter ID for each token + # LoRA config per adapter lora_ranks: torch.Tensor # [num_loras] lora_scalings: torch.Tensor # [num_loras] @@ -481,10 +484,8 @@ def _add_lora_gate_up_delta( (max_loras,), dtype=torch.int32, device=device ) - # Create token-to-LoRA mapping (assuming all tokens use LoRA 0 for now) - token_lora_mapping = torch.zeros( - (num_tokens,), dtype=torch.int32, device=device - ) + # Use the actual token-to-LoRA mapping from the forward batch + token_lora_mapping = 
lora_info.token_lora_indices.to(dtype=torch.int32, device=device) lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) # Call the kernel directly From 952c8d330cd19d8d92022fc4d94073cacfe30d56 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 3 Feb 2026 19:24:16 -0500 Subject: [PATCH 079/150] fix small issues in lora moe runners --- python/sglang/srt/lora/lora_moe_runners.py | 230 ++++++++++----------- 1 file changed, 115 insertions(+), 115 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 020b61c04fef..1c67af71b791 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -38,6 +38,32 @@ ) from sglang.srt.utils import is_cuda, is_hip +import sys +import os +# ============================================================================== +# IMPORT PREBUILT MOE LORA KERNEL +# IMPORT PREBUILT MOE LORA KERNEL +# ============================================================================== +# Define the full path to the shared object file +so_file_path = "/my-sglang/sglang/moe_lora_ops_debug.cpython-312-x86_64-linux-gnu.so" + +# Extract the directory containing the .so file +so_dir = os.path.dirname(so_file_path) + +# Add that DIRECTORY to sys.path +if so_dir not in sys.path: + sys.path.append(so_dir) + +moe_lora_ops = None +try: + # Python automatically handles the suffix (.cpython-312...) 
+ import moe_lora_ops_debug as moe_lora_ops +except ImportError as e: + print(f"WARNING: Could not import 'moe_lora_ops_debug' from {so_dir}.") + print(f"Detailed error: {e}") + print("LoRA MoE functionality will fail if requested.") +# ============================================================================= + _is_hip = is_hip() _is_cuda = is_cuda() @@ -65,9 +91,7 @@ class LoRAInfo: token_ids: torch.Tensor # [num_dispatched] - original token indices expert_ids: torch.Tensor # [num_dispatched] - expert IDs lora_ids: torch.Tensor # [num_dispatched] - LoRA adapter IDs - - # Original per-token LoRA mapping (unsorted) - token_lora_indices: torch.Tensor # [num_tokens] - LoRA adapter ID for each token + token_lora_mapping: torch.Tensor # [num_tokens] - LoRA adapter ID for each token # LoRA config per adapter lora_ranks: torch.Tensor # [num_loras] @@ -217,10 +241,7 @@ def run( dtype=hidden_states.dtype, ) - # output shape: [M, top_k, N] and in original token order since we do not pass in c_sorted=True. If we - # want to get the output in sorted order by expert, we can pass in c_sorted=True. - # TODO: determine whether we should pass in c_sorted=True. That will make it less readable and different from base run method. - # but we won't have to scatter the lora delta to the correct positions before adding them to this base output. 
+ invoke_fused_moe_kernel( hidden_states, w13, @@ -246,16 +267,76 @@ def run( block_shape=block_shape, ) + # ============================== + # Perform LoRA alignment for both gate up and gate down operations + # Define shrink_config for LoRA alignment + shrink_config = { + "BLOCK_SIZE_M": 64 + } # Default block size, can be made configurable + + # Prepare inputs for the kernel + block_size_m = shrink_config["BLOCK_SIZE_M"] + max_loras = len(lora_info.lora_ranks) + + # Calculate max_num_tokens_padded + max_num_tokens_padded = topk_ids.numel() + lora_info.num_experts * ( + block_size_m - 1 + ) + max_num_tokens_padded = ( + (max_num_tokens_padded + block_size_m - 1) // block_size_m + ) * block_size_m + max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m + + # Initialize output tensors (using torch.empty like the reference implementation) + device = topk_ids.device + # TODO: Jonahbernard: check if we can allocate these tensors in LoRAInfo object itself and reuse them for each layer + sorted_token_ids_lora = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=device, + ) + expert_ids_lora = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=device, + ) + num_tokens_post_padded_lora = torch.empty( + (max_loras,), dtype=torch.int32, device=device + ) + + # Get token-to-LoRA mapping from lora_info + token_lora_mapping = lora_info.token_lora_mapping + lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) + + moe_lora_ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping.to(torch.int32), + int(lora_info.num_experts), + int(block_size_m), + int(max_loras), + int(max_num_tokens_padded), + int(max_num_m_blocks), + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + lora_info.adapter_enabled.to(torch.int32), + lora_ids.to(torch.int32), + None, # expert_map + ) + + # ============================== + # 
============================================================ # Stage 1.5: Add LoRA gate_up delta BEFORE activation # ============================================================ self._add_lora_gate_up_delta( hidden_states=hidden_states, intermediate_cache=intermediate_cache1, - topk_ids=topk_ids, topk_weights=topk_weights, lora_info=lora_info, - num_tokens_post_padded=num_tokens_post_padded, + sorted_token_ids_lora=sorted_token_ids_lora, + expert_ids_lora=expert_ids_lora, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, ) # ============================================================ @@ -310,11 +391,7 @@ def run( else: out_hidden_states = torch.empty_like(hidden_states) - # output shape: [M, hidden_dim] and in original token order since we do not pass in c_sorted=True. If we - # want to get the output in sorted order by expert, we can pass in c_sorted=True. We use no_combine=False because the next - # LoRA computation requires input to be [M, hidden_dim] - # TODO: determine whether we should pass in c_sorted=True. That will make it less readable and different from base run method. - # but we won't have to scatter the lora delta to the correct positions before adding them to this base output. 
+ invoke_fused_moe_kernel( intermediate_cache2, w2, @@ -351,14 +428,13 @@ def run( self._add_lora_down_delta( intermediate_input=intermediate_cache2, intermediate_cache=intermediate_cache3, - topk_ids=topk_ids, topk_weights=topk_weights, - apply_router_weight_on_input=apply_router_weight_on_input, lora_info=lora_info, - num_tokens_post_padded=num_tokens_post_padded, + sorted_token_ids_lora=sorted_token_ids_lora, + expert_ids_lora=expert_ids_lora, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, ) - # we still need to combine the output # ============================================================ # Stage 4: Final reduction (sum across top_k) @@ -413,10 +489,11 @@ def _add_lora_gate_up_delta( self, hidden_states: torch.Tensor, # [M, hidden_dim] intermediate_cache: torch.Tensor, # [M, top_k, gate_up_dim] - topk_ids: torch.Tensor, # [M, top_k] topk_weights: torch.Tensor, # [M, top_k] lora_info: LoRAInfo, - num_tokens_post_padded: torch.Tensor, + sorted_token_ids_lora: torch.Tensor, + expert_ids_lora: torch.Tensor, + num_tokens_post_padded_lora: torch.Tensor, ) -> None: """ Add LoRA gate_up delta to intermediate_cache in-place. @@ -428,99 +505,30 @@ def _add_lora_gate_up_delta( from sglang.srt.lora.triton_ops import fused_moe_lora M, top_k, gate_up_dim = intermediate_cache.shape - num_dispatched = lora_info.token_ids.shape[0] - # Compute LoRA delta where intermediate_cache needs to be [M, top_k, gate_up_dim] in original token order. 
- # Output shape: [M, top_k, gate_up_dim] - # Hidden_states shape: [M, hidden_dim] (handles token duplication internally) # Skip LoRA computation if no LoRA adapters have non-zero rank if lora_info.max_lora_rank == 0: return - # Create num_tokens_post_padded tensor for vLLM kernel - # It expects shape (max_loras,) with the same value for each LoRA - num_loras = len(lora_info.lora_ranks) - num_tokens_post_padded_formatted = num_tokens_post_padded.expand(num_loras) actual_max_lora_rank = lora_info.max_lora_rank lora_a_stacked = [lora_info.gate_up_lora_a_weights] lora_b_stacked = [lora_info.gate_up_lora_b_weights] - # Define shrink_config for LoRA alignment - shrink_config = { - "BLOCK_SIZE_M": 64 - } # Default block size, can be made configurable - - # Call moe_lora_align_block_size before the LoRA gate_up_proj delta path - # Prepare inputs for the kernel - block_size_m = shrink_config["BLOCK_SIZE_M"] max_loras = len(lora_info.lora_ranks) - num_tokens = M - topk_num = topk_ids.shape[1] - - # Calculate max_num_tokens_padded - max_num_tokens_padded = topk_ids.numel() + lora_info.num_experts * ( - block_size_m - 1 - ) - max_num_tokens_padded = ( - (max_num_tokens_padded + block_size_m - 1) // block_size_m - ) * block_size_m - max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m - - # Initialize output tensors (using torch.empty like the reference implementation) - device = topk_ids.device - sorted_token_ids_lora = torch.empty( - (max_loras * max_num_tokens_padded,), - dtype=torch.int32, - device=device, - ) - expert_ids_lora = torch.empty( - (max_loras * max_num_m_blocks,), - dtype=torch.int32, - device=device, - ) - num_tokens_post_padded_lora = torch.empty( - (max_loras,), dtype=torch.int32, device=device - ) - - # Use the actual token-to-LoRA mapping from the forward batch - token_lora_mapping = lora_info.token_lora_indices.to(dtype=torch.int32, device=device) - lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) - - # Call 
the kernel directly - torch.ops.sgl_kernel.moe_lora_align_block_size( - topk_ids, - token_lora_mapping, - lora_info.num_experts, - block_size_m, - max_loras, - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids_lora, - expert_ids_lora, - num_tokens_post_padded_lora, - lora_info.adapter_enabled, - lora_ids, - None, # expert_map, set to None for now - ) - - max_loras = len(lora_info.lora_ranks) - expert_ids_lora = expert_ids_lora.view(max_loras, -1) - sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) - - # TODO: Jonahbernard: we need some way to pass these on to the down_proj delta path - lora_info.expert_ids = expert_ids_lora - lora_info.token_ids = sorted_token_ids_lora # it could be an issue overwriting these two fields here. + # Reshape the sorted tensors for fused_moe_lora (expects 2D: max_loras x max_num_tokens_padded) + sorted_token_ids_reshaped = sorted_token_ids_lora.view(max_loras, -1) + expert_ids_reshaped = expert_ids_lora.view(max_loras, -1) fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=hidden_states, lora_a_stacked=lora_a_stacked, lora_b_stacked=lora_b_stacked, - topk_weights=topk_weights, # Use actual routing weights - sorted_token_ids=sorted_token_ids_lora, - expert_ids=expert_ids_lora, + topk_weights=topk_weights, + sorted_token_ids=sorted_token_ids_reshaped, + expert_ids=expert_ids_reshaped, num_tokens_post_padded=num_tokens_post_padded_lora, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, @@ -546,11 +554,11 @@ def _add_lora_down_delta( self, intermediate_input: torch.Tensor, # [M * top_k, intermediate_dim] intermediate_cache: torch.Tensor, # [M, top_k, hidden_dim] - topk_ids: torch.Tensor, # [M, top_k] topk_weights: torch.Tensor, # [M, top_k] - apply_router_weight_on_input: bool, lora_info: LoRAInfo, - num_tokens_post_padded: torch.Tensor, + sorted_token_ids_lora: torch.Tensor, + expert_ids_lora: torch.Tensor, + num_tokens_post_padded_lora: torch.Tensor, ) -> None: """ Add LoRA down delta to 
intermediate_cache in-place. @@ -563,23 +571,13 @@ def _add_lora_down_delta( M, top_k, hidden_dim = intermediate_cache.shape - # intermediate_input is the input from the previous stage and is in original token order. - - # Data format adaptation for vLLM kernel - num_dispatched_down = lora_info.token_ids.shape[0] # Skip LoRA computation if no LoRA adapters have non-zero rank if lora_info.max_lora_rank == 0: return - # Create num_tokens_post_padded tensor for vLLM kernel - # It expects shape (max_loras,) with the same value for each LoRA - num_loras = len(lora_info.lora_ranks) actual_max_lora_rank = lora_info.max_lora_rank - # Handle multi-LoRA: stack weights for all loaded LoRAs - # lora_info.down_lora_a_weights shape: [num_loras, num_experts, max_rank, intermediate_dim] - # Note: LoRA scaling factors (lora_info.lora_scalings) are already applied to weights during loading max_loras = len(lora_info.lora_ranks) # Validate weight dimensions match expectations @@ -593,19 +591,20 @@ def _add_lora_down_delta( lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] - # TODO: Jonahbernard: is this variable needed? 
- num_tokens_post_padded_lora = torch.empty( - (max_loras,), dtype=torch.int32, device=topk_weights.device - ) + max_loras = len(lora_info.lora_ranks) + + # Reshape the sorted tensors for fused_moe_lora (expects 2D: max_loras x max_num_tokens_padded) + sorted_token_ids_reshaped = sorted_token_ids_lora.view(max_loras, -1) + expert_ids_reshaped = expert_ids_lora.view(max_loras, -1) fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, lora_a_stacked=lora_a_stacked, lora_b_stacked=lora_b_stacked, - topk_weights=topk_weights, # Use the routing weights passed to this function - sorted_token_ids=lora_info.token_ids, # this is the token_ids from the previous stage - expert_ids=lora_info.expert_ids, # this is the expert_ids from the previous stage + topk_weights=topk_weights, + sorted_token_ids=sorted_token_ids_reshaped, + expert_ids=expert_ids_reshaped, num_tokens_post_padded=num_tokens_post_padded_lora, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, @@ -625,4 +624,5 @@ def _add_lora_down_delta( expand_num_warps=4, expand_num_stages=2, expand_split_k=1, + mul_routed_weight=True, ) From 49ac7124214423421e37733a0dca2cc3c8fc8d59 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 3 Feb 2026 19:25:29 -0500 Subject: [PATCH 080/150] fix small issue in layers.py --- python/sglang/srt/lora/layers.py | 37 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index d5b70148411e..ea7a31d1318d 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -602,7 +602,6 @@ def __init__( self.down_lora_a_weights = None self.down_lora_b_weights = None self._lora_runner = None - self.max_lora_rank = 0 # Will be set by LoRAManager def set_lora_info( self, @@ -631,7 +630,6 @@ def _get_lora_info( return None from sglang.srt.lora.lora_moe_runners import LoRAInfo - from sglang.srt.lora.moe_dispatch import moe_dispatch # 
Get dispatch info from TopKOutput topk_ids = topk_output.topk_ids # [num_tokens, top_k] @@ -642,24 +640,26 @@ def _get_lora_info( lora_ranks = batch_info.lora_ranks # [num_loras] scalings = batch_info.scalings # [num_loras] - # Use global max LoRA rank set by LoRAManager - max_lora_rank = self.max_lora_rank - - # Create adapter_enabled tensor for the current batch - # All LoRAs in the batch are enabled by definition - adapter_enabled = torch.ones( - len(lora_ranks), dtype=torch.int32, device=lora_ranks.device - ) + # Compute max LoRA rank from current batch ranks + max_lora_rank = int(torch.max(lora_ranks)) # Use precomputed per-token LoRA indices from forward batch lora_indices = self.lora_backend.forward_batch.token_lora_indices - # Dispatch tokens to experts - token_ids, expert_ids, _, lora_ids = moe_dispatch( - topk_ids=topk_ids, - topk_weights=topk_weights, - lora_indices=lora_indices, - ) + # Create adapter_enabled tensor for the current batch + # Only enable LoRA adapters that are actually used in this batch + # TODO: Jonahbernard: check that this doesn't slow down inference for this batch + adapter_enabled = torch.zeros(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) + unique_lora_ids = torch.unique(lora_indices) + adapter_enabled[unique_lora_ids] = 1 + + # TODO: Jonahbernard: check if this is correct + # Flatten dispatch info (no longer using moe_dispatch) + num_tokens, top_k = topk_ids.shape + device = topk_ids.device + token_ids = torch.arange(num_tokens, device=device, dtype=torch.int32).repeat_interleave(top_k) + expert_ids = topk_ids.flatten().to(torch.int32) + lora_ids = lora_indices.repeat_interleave(top_k) return LoRAInfo( gate_up_lora_a_weights=self.gate_up_lora_a_weights, @@ -669,7 +669,7 @@ def _get_lora_info( token_ids=token_ids, expert_ids=expert_ids, lora_ids=lora_ids, - token_lora_indices=lora_indices, + token_lora_mapping=lora_indices, lora_ranks=lora_ranks, lora_scalings=scalings, adapter_enabled=adapter_enabled, @@ -724,11 
+724,10 @@ def _forward_with_lora( # Create LoRA-enabled MoeRunner if not already created if self._lora_runner is None: from sglang.srt.layers.moe.moe_runner.runner import MoeRunner - self._lora_runner = MoeRunner( base_layer.quant_method.runner.runner_backend, base_layer.moe_runner_config, - lora_enabled=True, + lora_enabled=True ) # Build quant info (for unquantized, this is straightforward) From 468a10f2045ae7f6f62056a8ddb9dd0fbf370a9e Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 3 Feb 2026 19:26:57 -0500 Subject: [PATCH 081/150] remove custom kernel build path --- python/sglang/srt/lora/lora_moe_runners.py | 26 ---------------------- 1 file changed, 26 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 1c67af71b791..0b0d32b0542a 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -38,32 +38,6 @@ ) from sglang.srt.utils import is_cuda, is_hip -import sys -import os -# ============================================================================== -# IMPORT PREBUILT MOE LORA KERNEL -# IMPORT PREBUILT MOE LORA KERNEL -# ============================================================================== -# Define the full path to the shared object file -so_file_path = "/my-sglang/sglang/moe_lora_ops_debug.cpython-312-x86_64-linux-gnu.so" - -# Extract the directory containing the .so file -so_dir = os.path.dirname(so_file_path) - -# Add that DIRECTORY to sys.path -if so_dir not in sys.path: - sys.path.append(so_dir) - -moe_lora_ops = None -try: - # Python automatically handles the suffix (.cpython-312...) 
- import moe_lora_ops_debug as moe_lora_ops -except ImportError as e: - print(f"WARNING: Could not import 'moe_lora_ops_debug' from {so_dir}.") - print(f"Detailed error: {e}") - print("LoRA MoE functionality will fail if requested.") -# ============================================================================= - _is_hip = is_hip() _is_cuda = is_cuda() From e8259a3a9d4c3c67bf627397734c5289a991d646 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 3 Feb 2026 19:29:41 -0500 Subject: [PATCH 082/150] remove unused code --- python/sglang/srt/lora/moe_dispatch.py | 56 --- .../lora/triton_ops/per_expert_lora_moe.py | 351 ------------------ 2 files changed, 407 deletions(-) delete mode 100644 python/sglang/srt/lora/moe_dispatch.py delete mode 100644 python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py diff --git a/python/sglang/srt/lora/moe_dispatch.py b/python/sglang/srt/lora/moe_dispatch.py deleted file mode 100644 index de7106a489b9..000000000000 --- a/python/sglang/srt/lora/moe_dispatch.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""MoE dispatch utilities.""" - -import torch - - -def moe_dispatch( - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - lora_indices: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Dispatch tokens to experts for MoE computation. 
- - Args: - topk_ids: [num_tokens, top_k] - Expert IDs selected by router - topk_weights: [num_tokens, top_k] - Router weights - lora_indices: [num_tokens] - LoRA adapter ID for each token - - Returns: - sorted_token_ids: Token indices sorted by expert_id - sorted_expert_ids: Corresponding expert IDs - sorted_topk_weights: Corresponding router weights - sorted_lora_ids: LoRA adapter IDs for each dispatched token - """ - num_tokens, top_k = topk_ids.shape - device = topk_ids.device - - # Flatten topk dimensions: [num_tokens * top_k] - flat_topk_ids = topk_ids.flatten() - flat_topk_weights = topk_weights.flatten() - flat_token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k) - flat_lora_ids = lora_indices.repeat_interleave(top_k) - - # Sort by expert_id only (each expert uses same LoRA adapter logic) - sorted_indices = torch.argsort(flat_topk_ids) - - sorted_token_ids = flat_token_ids[sorted_indices] - sorted_expert_ids = flat_topk_ids[sorted_indices] - sorted_topk_weights = flat_topk_weights[sorted_indices] - sorted_lora_ids = flat_lora_ids[sorted_indices] - - return sorted_token_ids, sorted_expert_ids, sorted_topk_weights, sorted_lora_ids diff --git a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py b/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py deleted file mode 100644 index 5537c3aa89ec..000000000000 --- a/python/sglang/srt/lora/triton_ops/per_expert_lora_moe.py +++ /dev/null @@ -1,351 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Per-expert LoRA computation kernel for MoE layers.""" - -import torch -import triton -import triton.language as tl - - -@triton.jit -def _per_expert_lora_kernel( - # Input/Output pointers - hidden_states_ptr, # [num_total_tokens, input_dim] - lora_a_weights_ptr, # [num_loras, num_experts, max_rank, input_dim] - lora_b_weights_ptr, # [num_loras, num_experts, output_dim, max_rank] - output_ptr, # [num_total_tokens, output_dim] - base output (modified in-place) - lora_output_ptr, # [num_total_tokens, output_dim] - separate LoRA-only output - # Dispatch info (length = num_dispatched) - token_ids_ptr, # [num_dispatched] -> index into hidden/output - expert_ids_ptr, # [num_dispatched] - lora_ids_ptr, # [num_dispatched] - # Dimensions - input_dim: tl.constexpr, - output_dim: tl.constexpr, - max_rank: tl.constexpr, - num_experts: tl.constexpr, - num_dispatched, - # Strides for 4D LoRA A weights [num_loras, num_experts, max_rank, input_dim] - lora_a_stride_lora: tl.constexpr, - lora_a_stride_expert: tl.constexpr, - lora_a_stride_rank: tl.constexpr, - lora_a_stride_input: tl.constexpr, - # Strides for 4D LoRA B weights [num_loras, num_experts, output_dim, max_rank] - lora_b_stride_lora: tl.constexpr, - lora_b_stride_expert: tl.constexpr, - lora_b_stride_output: tl.constexpr, - lora_b_stride_rank: tl.constexpr, - # LoRA ranks per adapter [num_loras] - lora_ranks_ptr, - # Scaling factors per adapter [num_loras] - lora_scalings_ptr, - # Block size (used for input and output tiling; rank is not tiled) - BLOCK_SIZE: tl.constexpr, - # Whether this is down_proj (affects stacking factor for rank calculation) - IS_DOWN_PROJ: tl.constexpr, -): - """ - Compute per-expert LoRA delta: - - delta[token, out_slice, lora] = B[out_slice, :] @ (A @ hidden_states[token]) - - 3D Grid: (spatial, slices, 
loras) - - spatial = program_id(0): dispatched token index - - slices = program_id(1): tile index along output_dim - - loras = program_id(2): LoRA adapter index - """ - - # 3D grid indices - spatial_id = tl.program_id(0) # dispatched token index - slice_id = tl.program_id(1) # output slice index - lora_id_grid = tl.program_id(2) # LoRA adapter index - - # Bounds check on dispatched tokens - if spatial_id >= num_dispatched: - return - - # Load dispatch info for this dispatched index - actual_token_id = tl.load(token_ids_ptr + spatial_id) - expert_id = tl.load(expert_ids_ptr + spatial_id) - token_lora_id = tl.load(lora_ids_ptr + spatial_id) - - # Skip if this token does not use this LoRA adapter - if token_lora_id != lora_id_grid: - return - - # Load LoRA rank and scaling (scalar tensors) for this LoRA adapter - rank = tl.load(lora_ranks_ptr + lora_id_grid) - scaling = tl.load(lora_scalings_ptr + lora_id_grid) - - # Adjust rank for stacked modules (gate_up_proj has stacking factor 2) - effective_rank = rank - if not IS_DOWN_PROJ: # gate_up_proj case - effective_rank = rank * 2 - - has_rank = effective_rank > 0 - if not has_rank: - return - - # ---------------------------- - # Base pointers - # ---------------------------- - # hidden_states[actual_token_id, :] - hidden_ptr = hidden_states_ptr + actual_token_id * input_dim - - # A[lora_id_grid, expert_id, :, :] - lora_a_base = ( - lora_a_weights_ptr - + lora_id_grid * lora_a_stride_lora - + expert_id * lora_a_stride_expert - ) - - # B[lora_id_grid, expert_id, :, :] - lora_b_base = ( - lora_b_weights_ptr - + lora_id_grid * lora_b_stride_lora - + expert_id * lora_b_stride_expert - ) - - # ---------------------------- - # Stage 1: intermediate = A @ hidden - # ---------------------------- - - # We assume max_rank is small enough to keep as a single 1D vector - r_offs = tl.arange(0, max_rank) # [max_rank] - rank_mask = r_offs < effective_rank # [max_rank] - - # TODO (Jonahcb): check if it is better to allocate outside the 
kernel - # Accumulator for intermediate: [max_rank] - intermediate = tl.zeros((max_rank,), dtype=tl.float32) - - # Tile over input_dim in chunks of BLOCK_SIZE - NUM_INPUT_TILES = (input_dim + BLOCK_SIZE - 1) // BLOCK_SIZE - for input_tile_idx in range(NUM_INPUT_TILES): - input_start = input_tile_idx * BLOCK_SIZE - input_offs = input_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] - input_mask = input_offs < input_dim # [BLOCK_SIZE] - - # Load input values for this tile: [BLOCK_SIZE] - h_vals = tl.load( - hidden_ptr + input_offs, - mask=input_mask, - other=0.0, - ).to(tl.float32) - - # Build [max_rank, BLOCK_SIZE] tile of A: - # rows: r_offs - # cols: input_offs - # offset = base + r * stride_rank + h * stride_input - a_ptrs = ( - lora_a_base - + r_offs[:, None] * lora_a_stride_rank - + input_offs[None, :] - * lora_a_stride_input # check if it is necessary to multiply by stride value as it should be contigious in this dimension - ) - a_vals = tl.load( - a_ptrs, - mask=rank_mask[:, None] & input_mask[None, :], - other=0.0, - ).to(tl.float32) - - # Dot over hidden axis: [max_rank] - # intermediate[r] += sum_h A[r, h] * h_vals[h] - intermediate += tl.sum(a_vals * h_vals[None, :], axis=1) - - # ---------------------------- - # Stage 2: y_slice = B[out_slice, :] @ intermediate - # One output slice per program along output_dim. 
- # ---------------------------- - out_start = slice_id * BLOCK_SIZE - out_offs = out_start + tl.arange(0, BLOCK_SIZE) # [BLOCK_SIZE] - out_mask = out_offs < output_dim # [BLOCK_SIZE] - - # Build [max_rank, BLOCK_SIZE] tile of B: - # rows: r_offs (rank dimension) - # cols: out_offs (output dimension) - # offset = base + out * stride_output + r * stride_rank - b_ptrs = ( - lora_b_base - + out_offs[None, :] * lora_b_stride_output - + r_offs[:, None] * lora_b_stride_rank - ) - b_vals = tl.load( - b_ptrs, - mask=rank_mask[:, None] & out_mask[None, :], - other=0.0, - ).to(tl.float32) - - # out_vals[j] = sum_r B[j, r] * intermediate[r] - out_vals = tl.sum(b_vals * intermediate[:, None], axis=0) # [BLOCK_SIZE] - - # Apply scaling - out_vals *= scaling - - # Router weights are applied in the final reduction step, not in the kernel - - # ---------------------------- - # Store results for each (token, expert) pair separately - # ---------------------------- - # Convert to output dtype (matches hidden_states dtype, could be float16/bfloat16/float32) - out_vals_typed = out_vals.to(output_ptr.dtype.element_ty) - - # Write to spatial_id position (each (token, expert) pair gets its own row) - out_row_base = spatial_id * output_dim - out_ptrs = output_ptr + out_row_base + out_offs - lora_out_ptrs = lora_output_ptr + out_row_base + out_offs - # Use regular store since each spatial_id is unique - tl.store(out_ptrs, out_vals_typed, out_mask) - tl.store(lora_out_ptrs, out_vals_typed, out_mask) - - -def per_expert_lora_forward( - hidden_states: torch.Tensor, - lora_a_weights: torch.Tensor, - lora_b_weights: torch.Tensor, - token_ids: torch.Tensor, - expert_ids: torch.Tensor, - lora_ids: torch.Tensor, - lora_ranks: torch.Tensor, - lora_scalings: torch.Tensor, - num_experts: int, - base_output: torch.Tensor = None, - is_down_proj: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass for per-expert LoRA computation using a 3D Triton grid: - grid = (spatial, slices, 
loras) - - Mathematically correct implementation that keeps expert outputs separate - until final reduction, matching the base Triton MoE pattern. - - Args: - hidden_states: [num_tokens, input_dim] where input_dim is hidden_dim for gate_up_proj - or intermediate_dim for down_proj - lora_a_weights: [num_loras, num_experts, max_rank, input_dim] - lora_b_weights: [num_loras, num_experts, output_dim, max_rank] - token_ids: [num_dispatched] - Original token indices - expert_ids: [num_dispatched] - Expert ID for each dispatched token - lora_ids: [num_dispatched] - LoRA ID for each dispatched token - lora_ranks: [num_loras] - Rank for each LoRA - lora_scalings: [num_loras] - Scaling factor for each LoRA - num_experts: Total number of experts - base_output: Output tensor with shape [num_dispatched, output_dim] - Each row contains the output for one (token, expert) pair - is_down_proj: Whether this is for down_proj (intermediate_dim -> hidden_dim) - or gate_up_proj (hidden_dim -> intermediate_dim) - - Returns: - tuple of: - output: LoRA delta for each (token, expert) pair - lora_output: Just the LoRA delta contribution (same as output) - """ - - # Shapes - num_tokens, input_dim = hidden_states.shape - num_loras, _, output_dim, _ = lora_b_weights.shape - num_dispatched = token_ids.shape[0] - - # Use fixed max_rank for consistent kernel compilation - # Maximum stacking factor is 2 (for gate_up_proj), so max_rank = max_lora_rank * 2 - # We assume max_lora_rank is reasonably small (e.g., 64-128) so max_rank = 256 is safe - max_rank = 256 # Conservative upper bound for max_lora_rank * 2 - - # Make sure everything is on the same device and contiguous - device = hidden_states.device - - # Use hidden_states dtype for consistency with model - dtype = hidden_states.dtype - hidden_states = hidden_states.contiguous() - lora_a_weights = lora_a_weights.contiguous() - lora_b_weights = lora_b_weights.contiguous() - token_ids = token_ids.contiguous() - expert_ids = expert_ids.contiguous() - 
lora_ids = lora_ids.contiguous() - lora_ranks = lora_ranks.contiguous() - lora_scalings = lora_scalings.contiguous() - - # Router weights are always applied in the final reduction step, never in the kernel - - # Always keep experts separate until final reduction - output_shape = (num_dispatched, output_dim) - - # Initialize or reuse output tensor for in-place addition - if base_output is None: - # Use specified dtype for consistency with model - output = torch.zeros( - *output_shape, - dtype=dtype, - device=device, - ) - else: - output = base_output - assert ( - output.shape == output_shape - ), f"Expected shape {output_shape}, got {output.shape}" - assert output.device == device - - # Allocate separate tensor for just the LoRA contribution - lora_output = torch.zeros( - *output_shape, - dtype=dtype, - device=device, - ) - - # Tile size for hidden and output dimensions - BLOCK_SIZE = 64 # tune as needed - - # Number of output slices along output_dim - num_slices = (output_dim + BLOCK_SIZE - 1) // BLOCK_SIZE - - # 3D grid: (spatial, slices, loras) - grid = (num_dispatched, num_slices, num_loras) - - _per_expert_lora_kernel[grid]( - # Pointers - hidden_states, # hidden_states_ptr - lora_a_weights, # lora_a_weights_ptr - lora_b_weights, # lora_b_weights_ptr - output, # output_ptr (base output, modified in-place) - lora_output, # lora_output_ptr (separate LoRA-only output) - # Dispatch info - token_ids, # token_ids_ptr - expert_ids, # expert_ids_ptr - lora_ids, # lora_ids_ptr - # Dimensions - input_dim, # input_dim (hidden_dim for gate_up_proj, intermediate_dim for down_proj) - output_dim, # output_dim (intermediate_dim for gate_up_proj, hidden_dim for down_proj) - max_rank, # max_rank - num_experts, # num_experts - num_dispatched, # num_dispatched (runtime scalar) - # LoRA A strides: [num_loras, num_experts, max_rank, input_dim] - lora_a_weights.stride(0), - lora_a_weights.stride(1), - lora_a_weights.stride(2), - lora_a_weights.stride(3), - # LoRA B strides: 
[num_loras, num_experts, output_dim, max_rank] - lora_b_weights.stride(0), - lora_b_weights.stride(1), - lora_b_weights.stride(2), - lora_b_weights.stride(3), - # Rank & scaling - lora_ranks, # lora_ranks_ptr - lora_scalings, # lora_scalings_ptr - # Block size (constexpr) - BLOCK_SIZE=BLOCK_SIZE, - # Whether this is down_proj - IS_DOWN_PROJ=is_down_proj, - ) - - return output, lora_output From 999dd7cfc15348d9c3344c2ff5bb4ed61cddaa0d Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 4 Feb 2026 16:24:55 -0500 Subject: [PATCH 083/150] fix --- python/sglang/srt/lora/lora_moe_runners.py | 67 ++++++++-------------- 1 file changed, 25 insertions(+), 42 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 0b0d32b0542a..5b1cd1ba3754 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -38,6 +38,7 @@ ) from sglang.srt.utils import is_cuda, is_hip + _is_hip = is_hip() _is_cuda = is_cuda() @@ -61,15 +62,11 @@ class LoRAInfo: ) # [num_loras, num_experts, max_rank, intermediate_dim] down_lora_b_weights: torch.Tensor # [num_loras, num_experts, hidden_dim, max_rank] - # Dispatch info (sorted by expert) - token_ids: torch.Tensor # [num_dispatched] - original token indices - expert_ids: torch.Tensor # [num_dispatched] - expert IDs - lora_ids: torch.Tensor # [num_dispatched] - LoRA adapter IDs - token_lora_mapping: torch.Tensor # [num_tokens] - LoRA adapter ID for each token + # Per-token LoRA adapter ID [num_tokens] + token_lora_mapping: torch.Tensor # LoRA config per adapter lora_ranks: torch.Tensor # [num_loras] - lora_scalings: torch.Tensor # [num_loras] adapter_enabled: torch.Tensor # [num_loras] - which adapters are enabled max_lora_rank: int # Maximum LoRA rank across all adapters @@ -215,7 +212,6 @@ def run( dtype=hidden_states.dtype, ) - invoke_fused_moe_kernel( hidden_states, w13, @@ -263,7 +259,6 @@ def run( # Initialize output tensors (using 
torch.empty like the reference implementation) device = topk_ids.device - # TODO: Jonahbernard: check if we can allocate these tensors in LoRAInfo object itself and reuse them for each layer sorted_token_ids_lora = torch.empty( (max_loras * max_num_tokens_padded,), dtype=torch.int32, @@ -284,7 +279,7 @@ def run( moe_lora_ops.moe_lora_align_block_size( topk_ids, - token_lora_mapping.to(torch.int32), + token_lora_mapping, int(lora_info.num_experts), int(block_size_m), int(max_loras), @@ -293,12 +288,14 @@ def run( sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora, - lora_info.adapter_enabled.to(torch.int32), - lora_ids.to(torch.int32), + lora_info.adapter_enabled, + lora_ids, None, # expert_map ) - # ============================== + # Reshape the sorted tensors for fused_moe_lora (expects 2D: max_loras x max_num_tokens_padded) + sorted_token_ids_reshaped = sorted_token_ids_lora.view(max_loras, -1) + expert_ids_reshaped = expert_ids_lora.view(max_loras, -1) # ============================================================ # Stage 1.5: Add LoRA gate_up delta BEFORE activation @@ -308,8 +305,8 @@ def run( intermediate_cache=intermediate_cache1, topk_weights=topk_weights, lora_info=lora_info, - sorted_token_ids_lora=sorted_token_ids_lora, - expert_ids_lora=expert_ids_lora, + sorted_token_ids_reshaped=sorted_token_ids_reshaped, + expert_ids_reshaped=expert_ids_reshaped, num_tokens_post_padded_lora=num_tokens_post_padded_lora, ) @@ -365,7 +362,6 @@ def run( else: out_hidden_states = torch.empty_like(hidden_states) - invoke_fused_moe_kernel( intermediate_cache2, w2, @@ -404,12 +400,11 @@ def run( intermediate_cache=intermediate_cache3, topk_weights=topk_weights, lora_info=lora_info, - sorted_token_ids_lora=sorted_token_ids_lora, - expert_ids_lora=expert_ids_lora, + sorted_token_ids_reshaped=sorted_token_ids_reshaped, + expert_ids_reshaped=expert_ids_reshaped, num_tokens_post_padded_lora=num_tokens_post_padded_lora, ) - # 
============================================================ # Stage 4: Final reduction (sum across top_k) # ============================================================ @@ -465,8 +460,8 @@ def _add_lora_gate_up_delta( intermediate_cache: torch.Tensor, # [M, top_k, gate_up_dim] topk_weights: torch.Tensor, # [M, top_k] lora_info: LoRAInfo, - sorted_token_ids_lora: torch.Tensor, - expert_ids_lora: torch.Tensor, + sorted_token_ids_reshaped: torch.Tensor, + expert_ids_reshaped: torch.Tensor, num_tokens_post_padded_lora: torch.Tensor, ) -> None: """ @@ -480,7 +475,6 @@ def _add_lora_gate_up_delta( M, top_k, gate_up_dim = intermediate_cache.shape - # Skip LoRA computation if no LoRA adapters have non-zero rank if lora_info.max_lora_rank == 0: return @@ -491,9 +485,10 @@ def _add_lora_gate_up_delta( lora_b_stacked = [lora_info.gate_up_lora_b_weights] max_loras = len(lora_info.lora_ranks) - # Reshape the sorted tensors for fused_moe_lora (expects 2D: max_loras x max_num_tokens_padded) - sorted_token_ids_reshaped = sorted_token_ids_lora.view(max_loras, -1) - expert_ids_reshaped = expert_ids_lora.view(max_loras, -1) + + lora_ids = torch.arange( + max_loras, dtype=torch.int32, device=hidden_states.device + ) fused_moe_lora( output=intermediate_cache, @@ -506,7 +501,7 @@ def _add_lora_gate_up_delta( num_tokens_post_padded=num_tokens_post_padded_lora, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, - lora_ids=lora_info.lora_ids, + lora_ids=lora_ids, adapter_enabled=lora_info.adapter_enabled, shrink_block_size_m=64, shrink_block_size_n=64, @@ -530,8 +525,8 @@ def _add_lora_down_delta( intermediate_cache: torch.Tensor, # [M, top_k, hidden_dim] topk_weights: torch.Tensor, # [M, top_k] lora_info: LoRAInfo, - sorted_token_ids_lora: torch.Tensor, - expert_ids_lora: torch.Tensor, + sorted_token_ids_reshaped: torch.Tensor, + expert_ids_reshaped: torch.Tensor, num_tokens_post_padded_lora: torch.Tensor, ) -> None: """ @@ -545,7 +540,6 @@ def _add_lora_down_delta( M, top_k, 
hidden_dim = intermediate_cache.shape - # Skip LoRA computation if no LoRA adapters have non-zero rank if lora_info.max_lora_rank == 0: return @@ -554,22 +548,11 @@ def _add_lora_down_delta( max_loras = len(lora_info.lora_ranks) - # Validate weight dimensions match expectations - assert ( - lora_info.down_lora_a_weights.shape[0] == max_loras - ), f"Expected {max_loras} LoRAs, got {lora_info.down_lora_a_weights.shape[0]}" - assert ( - lora_info.adapter_enabled.shape[0] >= max_loras - ), f"adapter_enabled too small: {lora_info.adapter_enabled.shape[0]} < {max_loras}" - lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] - max_loras = len(lora_info.lora_ranks) - - # Reshape the sorted tensors for fused_moe_lora (expects 2D: max_loras x max_num_tokens_padded) - sorted_token_ids_reshaped = sorted_token_ids_lora.view(max_loras, -1) - expert_ids_reshaped = expert_ids_lora.view(max_loras, -1) + device = intermediate_cache.device + lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) fused_moe_lora( output=intermediate_cache, @@ -582,7 +565,7 @@ def _add_lora_down_delta( num_tokens_post_padded=num_tokens_post_padded_lora, max_lora_rank=actual_max_lora_rank, top_k_num=top_k, - lora_ids=lora_info.lora_ids, + lora_ids=lora_ids, adapter_enabled=lora_info.adapter_enabled, shrink_block_size_m=64, shrink_block_size_n=64, From 5fcdfd51f1a6dbc6bb65d1f1747d31e1ed3966f5 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 4 Feb 2026 16:25:37 -0500 Subject: [PATCH 084/150] fix --- python/sglang/srt/lora/layers.py | 100 +++++++++++-------------------- 1 file changed, 34 insertions(+), 66 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index ea7a31d1318d..2de5f8b670bc 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -592,16 +592,33 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): def __init__( self, - base_layer: nn.Module, + base_layer: 
FusedMoE, lora_backend: BaseLoRABackend, ): + # initializes FusedMoE with its own moe_runner for base path super().__init__(base_layer, lora_backend) # LoRA tensors will be set by LoRAManager self.gate_up_lora_a_weights = None self.gate_up_lora_b_weights = None self.down_lora_a_weights = None self.down_lora_b_weights = None - self._lora_runner = None + + # initialize triton_lora moe runner for batches with lora enabled + from sglang.srt.layers.moe.moe_runner.runner import MoeRunner + from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo + self._lora_runner = MoeRunner( + base_layer.quant_method.runner.runner_backend, + base_layer.moe_runner_config, + lora_enabled=True + ) + + # Pre-compute quant info for efficiency (weights don't change during inference) + self._quant_info = TritonMoeQuantInfo( + w13_weight=base_layer.w13_weight, + w2_weight=base_layer.w2_weight, + b13=getattr(base_layer, "w13_weight_bias", None), + b2=getattr(base_layer, "w2_weight_bias", None), + ) def set_lora_info( self, @@ -626,52 +643,36 @@ def _get_lora_info( Returns None if LoRA is not enabled or weights are not set. 
""" - if not self.set_lora or self.gate_up_lora_a_weights is None: - return None - from sglang.srt.lora.lora_moe_runners import LoRAInfo - # Get dispatch info from TopKOutput - topk_ids = topk_output.topk_ids # [num_tokens, top_k] - topk_weights = topk_output.topk_weights # [num_tokens, top_k] - # Get LoRA batch info from backend batch_info = self.lora_backend.batch_info lora_ranks = batch_info.lora_ranks # [num_loras] - scalings = batch_info.scalings # [num_loras] # Compute max LoRA rank from current batch ranks - max_lora_rank = int(torch.max(lora_ranks)) + if hasattr(batch_info, "max_lora_rank") and batch_info.max_lora_rank is not None: + max_lora_rank = batch_info.max_lora_rank + else: + max_lora_rank = int(torch.max(lora_ranks)) - # Use precomputed per-token LoRA indices from forward batch - lora_indices = self.lora_backend.forward_batch.token_lora_indices + # Use precomputed per-token LoRA indices from forward batch (int32 for kernel use) + lora_indices = self.lora_backend.forward_batch.token_lora_indices.to( + torch.int32 + ) # Create adapter_enabled tensor for the current batch # Only enable LoRA adapters that are actually used in this batch # TODO: Jonahbernard: check that this doesn't slow down inference for this batch adapter_enabled = torch.zeros(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) - unique_lora_ids = torch.unique(lora_indices) - adapter_enabled[unique_lora_ids] = 1 - - # TODO: Jonahbernard: check if this is correct - # Flatten dispatch info (no longer using moe_dispatch) - num_tokens, top_k = topk_ids.shape - device = topk_ids.device - token_ids = torch.arange(num_tokens, device=device, dtype=torch.int32).repeat_interleave(top_k) - expert_ids = topk_ids.flatten().to(torch.int32) - lora_ids = lora_indices.repeat_interleave(top_k) + adapter_enabled.index_fill_(0, lora_indices.long(), 1) return LoRAInfo( gate_up_lora_a_weights=self.gate_up_lora_a_weights, gate_up_lora_b_weights=self.gate_up_lora_b_weights, 
down_lora_a_weights=self.down_lora_a_weights, down_lora_b_weights=self.down_lora_b_weights, - token_ids=token_ids, - expert_ids=expert_ids, - lora_ids=lora_ids, token_lora_mapping=lora_indices, lora_ranks=lora_ranks, - lora_scalings=scalings, adapter_enabled=adapter_enabled, max_lora_rank=max_lora_rank, num_experts=self.base_layer.num_experts, @@ -685,15 +686,11 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs 1. After gate_up projection, before activation 2. After down projection, before final reduction """ - # If LoRA is not enabled, just run base MoE - if not self.set_lora or self.gate_up_lora_a_weights is None: - return self.base_layer.forward(hidden_states, topk_output, **kwargs) # Build LoRA info for this batch lora_info = self._get_lora_info(topk_output) - # For now, we use the integrated runner approach only for Triton backend - # This wraps the base layer's forward with LoRA integration + # run lora moe_runner return self._forward_with_lora(hidden_states, topk_output, lora_info, **kwargs) def _forward_with_lora( @@ -705,54 +702,25 @@ def _forward_with_lora( ): """ Run MoE forward with LoRA integration at the correct points. - - This method hooks into the base layer's computation to add LoRA deltas - at the right stages. 
""" - from sglang.srt.lora.lora_moe_runners import TritonRunnerCoreWithLoRA - from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo - # Get the base layer's dispatch and combine logic base_layer = self.base_layer - origin_hidden_states_dim = hidden_states.shape[-1] - # Dispatch tokens + # Dispatch tokens (doesn't do much in the LoRA case) dispatch_output = base_layer.dispatcher.dispatch( hidden_states=hidden_states, topk_output=topk_output ) - # Create LoRA-enabled MoeRunner if not already created - if self._lora_runner is None: - from sglang.srt.layers.moe.moe_runner.runner import MoeRunner - self._lora_runner = MoeRunner( - base_layer.quant_method.runner.runner_backend, - base_layer.moe_runner_config, - lora_enabled=True - ) + # Use pre-computed quant info (doesn't change so not sure why we need to pass it in every time) + quant_info = self._quant_info - # Build quant info (for unquantized, this is straightforward) - quant_info = TritonMoeQuantInfo( - w13_weight=base_layer.w13_weight, - w2_weight=base_layer.w2_weight, - b13=getattr(base_layer, "w13_weight_bias", None), - b2=getattr(base_layer, "w2_weight_bias", None), - ) - - # Run with LoRA integration using the MoeRunner infrastructure + # Run the only lora moe runner (Triton) combine_input = self._lora_runner.run( dispatch_output, quant_info, lora_info=lora_info ) - # Combine and return + # Combine and return (doesn't do much in the LoRA case) final_hidden_states = base_layer.dispatcher.combine(combine_input=combine_input) - final_hidden_states = final_hidden_states[ - ..., :origin_hidden_states_dim - ].contiguous() - - if base_layer.reduce_results and ( - base_layer.moe_tp_size > 1 or base_layer.moe_ep_size > 1 - ): - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states From f7cba2505f71dd6f15b45559475b07c0ec400fa5 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 4 Feb 2026 16:35:21 -0500 Subject: [PATCH 085/150] major fixes --- 
python/sglang/srt/lora/lora_manager.py | 3 +++ python/sglang/srt/lora/mem_pool.py | 7 ++++++- python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py | 2 ++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 638591215aad..246e81b55a01 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -283,6 +283,9 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): use_cuda_graph=use_cuda_graph, ) + # Attach max_lora_rank to batch_info for MoE usage + self.lora_backend.batch_info.max_lora_rank = max(lora_ranks) if lora_ranks else 0 + # Populate per-token LoRA indices from segment information batch_info = self.lora_backend.batch_info num_tokens = forward_batch.input_ids.shape[0] # Tokens in current forward pass diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index f143a8361efe..519ce438928e 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -628,7 +628,12 @@ def load_lora_weight_tensor( for expert_id, expert_weight in weights.items(): # Buffer shape: [num_loras, num_experts, intermediate_dim, max_rank] buffer_view = target_buffer[buffer_id, expert_id, :, :lora_rank] - load_lora_weight_tensor(buffer_view, expert_weight) + + weight_to_load = expert_weight + if weight_to_load is not None: + weight_to_load = weight_to_load * lora_adapter.scaling + + load_lora_weight_tensor(buffer_view, weight_to_load) else: # Standard: single tensor per module buffer_view = target_buffer[buffer_id, :, :lora_rank] diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index e7d730c6bb70..fb27dbf5caa7 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -94,6 +94,8 @@ def _fused_moe_lora_kernel( launch_pdl: 
tl.constexpr, IS_PRIMARY: tl.constexpr, ): + # TODO (Jonahcb): investigate why GDC is not working + USE_GDC = False pid = tl.program_id(axis=0) slice_id = tl.program_id(axis=1) lora_idx = tl.program_id(axis=2) From 688e0c2dcd9d956a2c38601b4eb198190e828c30 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 4 Feb 2026 16:50:06 -0500 Subject: [PATCH 086/150] fix --- .../csrc/moe/moe_lora_align_sum_kernel.cu | 751 +++++++++++------- 1 file changed, 445 insertions(+), 306 deletions(-) diff --git a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu index 1203217eb754..06f16cbf9253 100644 --- a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu @@ -1,11 +1,11 @@ // Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu +// TODO (Jonahcb): merge with moe_align_kernel.cu -#include #include #include +#include #include - #include #include "utils.h" @@ -19,21 +19,21 @@ namespace batched_moe_align_block_size { static constexpr int32_t num_threads = 1024; static constexpr int32_t num_blocks = 1; __global__ void batched_moe_align_block_size_kernel( - int32_t const num_batches, int32_t const max_tokens_per_batch, - int32_t const block_size, int32_t const* __restrict__ batch_num_tokens, - int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids, + int32_t const num_batches, + int32_t const max_tokens_per_batch, + int32_t const block_size, + int32_t const* __restrict__ batch_num_tokens, + int32_t* __restrict__ sorted_ids, + int32_t* __restrict__ block_ids, int32_t* __restrict__ num_tokens_post_pad) { // TODO(varun): This is a naive implementation. Could be optimized. 
size_t const batch_id = threadIdx.x; size_t const stride = blockDim.x * gridDim.x; - int32_t const num_blocks_per_batch = - CEILDIV(max_tokens_per_batch, block_size); - int32_t const sorted_ids_size = - num_blocks_per_batch * num_batches * block_size; + int32_t const num_blocks_per_batch = CEILDIV(max_tokens_per_batch, block_size); + int32_t const sorted_ids_size = num_blocks_per_batch * num_batches * block_size; int32_t const block_ids_size = sorted_ids_size / block_size; - int32_t const SENTINEL = - num_batches * max_tokens_per_batch; // To denote invalid entries. + int32_t const SENTINEL = num_batches * max_tokens_per_batch; // To denote invalid entries. // Intialize sorted_ids for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) { sorted_ids[i] = SENTINEL; @@ -47,8 +47,7 @@ __global__ void batched_moe_align_block_size_kernel( if (batch_id < num_batches) { b_num_tokens = batch_num_tokens[batch_id]; } - int32_t const ceil_b_num_tokens = - CEILDIV(b_num_tokens, block_size) * block_size; + int32_t const ceil_b_num_tokens = CEILDIV(b_num_tokens, block_size) * block_size; // Compute prefix sum over token counts per expert using BlockScan = cub::BlockScan; @@ -80,13 +79,23 @@ __global__ void batched_moe_align_block_size_kernel( template __device__ void _moe_align_block_size( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, int32_t num_experts, - int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, - size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, - int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id, - int32_t topk_num, int32_t* token_mask, bool has_expert_map) { + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t padded_num_experts, + int32_t 
experts_per_warp, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, + int32_t model_offset, + int32_t inactive_expert_id, + int32_t topk_num, + int32_t* token_mask, + bool has_expert_map) { extern __shared__ int32_t shared_counts[]; // Compute input buffer offsets. Typically these will all be 0, except when @@ -99,8 +108,7 @@ __device__ void _moe_align_block_size( // This is safe since the current kernel does not use sorted_token_ids. if (blockIdx.x % 2) { // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; - it += blockDim.x) { + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { sorted_token_ids[sorted_token_ids_offset + it] = numel; } return; @@ -133,8 +141,7 @@ __device__ void _moe_align_block_size( int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; - atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], - mask); + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], mask); } __syncthreads(); @@ -165,15 +172,13 @@ __device__ void _moe_align_block_size( __syncthreads(); if (threadIdx.x < num_experts) { - for (int i = cumsum[cumsum_offset + threadIdx.x]; - i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { + for (int i = cumsum[cumsum_offset + threadIdx.x]; i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { expert_ids[expert_ids_offset + i / block_size] = threadIdx.x; } } // Fill remaining expert_ids with 0 - const size_t fill_start_idx = - cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; + const size_t fill_start_idx = cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { expert_ids[expert_ids_offset + i] = inactive_expert_id; } @@ -182,12 
+187,20 @@ __device__ void _moe_align_block_size( template __device__ void _moe_align_block_size_small_batch_expert( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, - size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks, - int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num, - int32_t* token_mask, bool has_expert_map) { + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, + int32_t inactive_expert_id, + int32_t model_offset, + int32_t topk_num, + int32_t* token_mask, + bool has_expert_map) { // Compute input buffer offsets. Typically these will all be 0, except when // using Multi LoRA. int sorted_token_ids_offset = max_num_tokens_padded * model_offset; @@ -199,8 +212,7 @@ __device__ void _moe_align_block_size_small_batch_expert( // synchronization easier. 
if (threadIdx.x < fill_threads) { // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; - it += fill_threads) { + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += fill_threads) { sorted_token_ids[sorted_token_ids_offset + it] = numel; } // Three __syncthreads() corresponding to the other threads @@ -237,8 +249,7 @@ __device__ void _moe_align_block_size_small_batch_expert( if (tid < num_experts) { tokens_cnts[tid] = 0; for (int i = 1; i <= stride; ++i) { - tokens_cnts[i * num_experts + tid] += - tokens_cnts[(i - 1) * num_experts + tid]; + tokens_cnts[i * num_experts + tid] += tokens_cnts[(i - 1) * num_experts + tid]; } } @@ -247,13 +258,9 @@ __device__ void _moe_align_block_size_small_batch_expert( if (tid == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = - cumsum[i - 1] + - CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * - block_size; + cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * block_size; } - total_tokens_post_pad[model_offset] = - static_cast(cumsum[num_experts]); + total_tokens_post_pad[model_offset] = static_cast(cumsum[num_experts]); } __syncthreads(); @@ -277,8 +284,7 @@ __device__ void _moe_align_block_size_small_batch_expert( // filter invalid expert if (expert_id == -1) continue; } - int32_t rank_post_pad = - tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; if (token_mask == nullptr || token_mask[i / topk_num]) { sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; @@ -290,10 +296,16 @@ __device__ void _moe_align_block_size_small_batch_expert( template __device__ void _count_and_sort_expert_tokens( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, - 
int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask, - int32_t model_offset, int32_t topk_num, bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, + size_t numel, + int32_t num_experts, + int32_t max_num_tokens_padded, + int32_t* __restrict__ token_mask, + int32_t model_offset, + int32_t topk_num, + bool has_expert_map) { const size_t tid = blockIdx.y * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.y; @@ -310,10 +322,8 @@ __device__ void _count_and_sort_expert_tokens( } if (token_mask == nullptr || token_mask[i / topk_num]) { - int32_t rank_post_pad = atomicAdd( - &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); - sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = - i; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); + sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = i; } } } @@ -321,28 +331,63 @@ __device__ void _count_and_sort_expert_tokens( template __global__ void moe_align_block_size_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, int32_t num_experts, - int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, - size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, - int32_t topk_num, bool has_expert_map) { + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t padded_num_experts, + int32_t experts_per_warp, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + int32_t max_num_tokens_padded, + int32_t topk_num, + bool has_expert_map) { _moe_align_block_size( - topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, 
expert_map, - num_experts, padded_num_experts, experts_per_warp, block_size, numel, - cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size), - 0, 0, topk_num, nullptr, has_expert_map); + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + padded_num_experts, + experts_per_warp, + block_size, + numel, + cumsum, + max_num_tokens_padded, + CEILDIV(max_num_tokens_padded, block_size), + 0, + 0, + topk_num, + nullptr, + has_expert_map); } template __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, - int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, + size_t numel, + int32_t num_experts, + int32_t max_num_tokens_padded, + int32_t topk_num, + bool has_expert_map) { _count_and_sort_expert_tokens( - topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, - max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map); + topk_ids, + sorted_token_ids, + cumsum_buffer, + expert_map, + numel, + num_experts, + max_num_tokens_padded, + nullptr, + 0, + topk_num, + has_expert_map); } template @@ -364,29 +409,56 @@ __global__ void moe_sum_kernel( template __global__ void moe_align_block_size_small_batch_expert_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, - size_t numel, int32_t max_num_tokens_padded, int32_t topk_num, + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t block_size, + 
size_t numel, + int32_t max_num_tokens_padded, + int32_t topk_num, bool has_expert_map) { _moe_align_block_size_small_batch_expert( - topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, - num_experts, block_size, numel, max_num_tokens_padded, - CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr, + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + block_size, + numel, + max_num_tokens_padded, + CEILDIV(max_num_tokens_padded, block_size), + 0, + 0, + topk_num, + nullptr, has_expert_map); } template __global__ void moe_lora_align_block_size_kernel( - scalar_t* __restrict__ topk_ids, int32_t* __restrict__ token_lora_mapping, - int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, - int max_loras, size_t numel, int max_num_tokens_padded, - int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, int32_t topk_num, - int32_t* total_tokens_post_pad, int32_t* adapter_enabled, - int32_t* __restrict__ cumsum, int32_t experts_per_warp, - int32_t padded_num_experts, int32_t* lora_ids, - int32_t* __restrict__ token_mask, bool has_expert_map) { + scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ token_lora_mapping, + int64_t block_size, + int32_t* __restrict__ expert_map, + int num_experts, + int max_loras, + size_t numel, + int max_num_tokens_padded, + int max_num_m_blocks, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t topk_num, + int32_t* total_tokens_post_pad, + int32_t* adapter_enabled, + int32_t* __restrict__ cumsum, + int32_t experts_per_warp, + int32_t padded_num_experts, + int32_t* lora_ids, + int32_t* __restrict__ token_mask, + bool has_expert_map) { int lora_idx = blockIdx.x / 2; int lora_id = lora_ids[lora_idx]; if (lora_id == -1 || adapter_enabled[lora_id] == 0) { @@ -399,27 +471,46 @@ __global__ void moe_lora_align_block_size_kernel( total_tokens_post_pad[lora_id] = 0; for 
(int i = 0; i < num_tokens; i++) { - token_mask[(lora_id * num_tokens) + i] = - (int)token_lora_mapping[i] == lora_id; + token_mask[(lora_id * num_tokens) + i] = (int)token_lora_mapping[i] == lora_id; } } __syncthreads(); _moe_align_block_size( - topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, - num_experts, padded_num_experts, experts_per_warp, block_size, numel, - cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num, - &token_mask[(lora_id * num_tokens)], has_expert_map); + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + padded_num_experts, + experts_per_warp, + block_size, + numel, + cumsum, + max_num_tokens_padded, + max_num_m_blocks, + lora_id, + -1, + topk_num, + &token_mask[(lora_id * num_tokens)], + has_expert_map); } template __global__ void lora_count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, - int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask, - int32_t* lora_ids, bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, + size_t numel, + int32_t num_experts, + int32_t max_num_tokens_padded, + int32_t topk_num, + int32_t* token_mask, + int32_t* lora_ids, + bool has_expert_map) { int lora_idx = blockIdx.x; int lora_id = lora_ids[lora_idx]; if (lora_id == -1) { @@ -429,20 +520,38 @@ __global__ void lora_count_and_sort_expert_tokens_kernel( int num_tokens = numel / topk_num; _count_and_sort_expert_tokens( - topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, - max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id, - topk_num, has_expert_map); + topk_ids, + sorted_token_ids, + cumsum_buffer, + expert_map, + numel, + num_experts, + max_num_tokens_padded, 
+ &token_mask[(lora_id * num_tokens)], + lora_id, + topk_num, + has_expert_map); } template __global__ void moe_lora_align_block_size_small_batch_expert_kernel( - scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping, - int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, - int max_loras, size_t numel, int max_num_tokens_padded, - int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, int topk_num, - int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids, - int32_t* token_mask, bool has_expert_map) { + scalar_t* __restrict__ topk_ids, + int32_t* token_lora_mapping, + int64_t block_size, + int32_t* __restrict__ expert_map, + int num_experts, + int max_loras, + size_t numel, + int max_num_tokens_padded, + int max_num_m_blocks, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int topk_num, + int32_t* total_tokens_post_pad, + int32_t* adapter_enabled, + int32_t* lora_ids, + int32_t* token_mask, + bool has_expert_map) { int lora_idx = blockIdx.x; int lora_id = lora_ids[lora_idx]; if (lora_id == -1 || adapter_enabled[lora_id] == 0) { @@ -454,17 +563,27 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( total_tokens_post_pad[lora_id] = 0; for (int i = 0; i < num_tokens; i++) { - token_mask[(lora_id * num_tokens) + i] = - (int)token_lora_mapping[i] == lora_id; + token_mask[(lora_id * num_tokens) + i] = (int)token_lora_mapping[i] == lora_id; } } __syncthreads(); _moe_align_block_size_small_batch_expert( - topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, - num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks, - -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)], + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + block_size, + numel, + max_num_tokens_padded, + max_num_m_blocks, + -1, + lora_id, + topk_num, + &token_mask[(lora_id * 
num_tokens)], has_expert_map); } @@ -472,24 +591,24 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( // taken from // https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad, - std::optional maybe_expert_map) { +void moe_align_block_size( + torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + std::optional maybe_expert_map) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int64_t padded_num_experts = - ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; int experts_per_warp = WARP_SIZE; int threads = 1024; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; // BlockScan uses 1024 threads and assigns one thread per expert. 
- TORCH_CHECK(padded_num_experts < 1024, - "padded_num_experts must be less than 1024"); - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); + auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); bool has_expert_map = maybe_expert_map.has_value(); torch::Tensor expert_map; if (has_expert_map) { @@ -498,86 +617,89 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, expert_map = torch::empty({0}, options_int); } - DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `cumsum` tensors - bool small_batch_expert_mode = - (topk_ids.numel() < 1024) && (num_experts <= 64); - - if (small_batch_expert_mode) { - const int32_t threads = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem_size = - ((threads + 1) * num_experts + (num_experts + 1)) * - sizeof(int32_t); - - // threadIdx.x >= fill_threads: counting experts and aligning - // threadIdx.x < fill_threads: filling sorted_token_ids - constexpr int32_t fill_threads = 256; - auto small_batch_expert_kernel = - moe::moe_align_block_size_small_batch_expert_kernel< - scalar_t, fill_threads>; - small_batch_expert_kernel<<<1, fill_threads + threads, - shared_mem_size, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - expert_map.data_ptr(), num_experts, block_size, - topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1), - has_expert_map); - } else { - torch::Tensor cumsum_buffer = - torch::empty({num_experts + 1}, options_int); - auto align_kernel = moe::moe_align_block_size_kernel; - - size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); - size_t shared_mem_size = - num_warps * experts_per_warp * sizeof(int32_t); - - // launch two threadblocks - // blockIdx.x == 0: 
counting experts and aligning - // blockIdx.x == 1: filling sorted_token_ids - align_kernel<<<2, threads, shared_mem_size, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - expert_map.data_ptr(), num_experts, padded_num_experts, - experts_per_warp, block_size, topk_ids.numel(), - cumsum_buffer.data_ptr(), sorted_token_ids.size(0), - topk_ids.size(1), has_expert_map); - - const int block_threads = std::min(256, (int)threads); - const int num_blocks = - (topk_ids.numel() + block_threads - 1) / block_threads; - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - dim3 gridDims(1, actual_blocks); - - auto sort_kernel = - moe::count_and_sort_expert_tokens_kernel; - sort_kernel<<>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), expert_map.data_ptr(), - topk_ids.numel(), num_experts, sorted_token_ids.size(0), - topk_ids.size(1), has_expert_map); - } - }); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `cumsum` tensors + bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t threads = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + auto small_batch_expert_kernel = moe::moe_align_block_size_small_batch_expert_kernel; + small_batch_expert_kernel<<<1, fill_threads + threads, shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + sorted_token_ids.size(0), 
+ topk_ids.size(1), + has_expert_map); + } else { + torch::Tensor cumsum_buffer = torch::empty({num_experts + 1}, options_int); + auto align_kernel = moe::moe_align_block_size_kernel; + + size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); + size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); + + // launch two threadblocks + // blockIdx.x == 0: counting experts and aligning + // blockIdx.x == 1: filling sorted_token_ids + align_kernel<<<2, threads, shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), + num_experts, + padded_num_experts, + experts_per_warp, + block_size, + topk_ids.numel(), + cumsum_buffer.data_ptr(), + sorted_token_ids.size(0), + topk_ids.size(1), + has_expert_map); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + dim3 gridDims(1, actual_blocks); + + auto sort_kernel = moe::count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), + expert_map.data_ptr(), + topk_ids.numel(), + num_experts, + sorted_token_ids.size(0), + topk_ids.size(1), + has_expert_map); + } + }); } -void batched_moe_align_block_size(int64_t max_tokens_per_batch, - int64_t block_size, - torch::Tensor const& batch_num_tokens, - torch::Tensor sorted_ids, - torch::Tensor batch_ids, - torch::Tensor num_tokens_post_pad) { +void batched_moe_align_block_size( + int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& batch_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor batch_ids, + torch::Tensor num_tokens_post_pad) { namespace batched_kernel = moe::batched_moe_align_block_size; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int32_t const 
B = batch_num_tokens.size(0); - int32_t const num_blocks_per_batch = - round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; + int32_t const num_blocks_per_batch = round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; int32_t const num_blocks = num_blocks_per_batch * B; int64_t const sorted_ids_size = num_blocks * block_size; @@ -586,15 +708,20 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch, TORCH_CHECK(num_tokens_post_pad.size(0) == 1); TORCH_CHECK(B <= batched_kernel::num_threads); - batched_kernel::batched_moe_align_block_size_kernel<<< - batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>( - B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr(), - sorted_ids.data_ptr(), batch_ids.data_ptr(), - num_tokens_post_pad.data_ptr()); + batched_kernel:: + batched_moe_align_block_size_kernel<<>>( + B, + max_tokens_per_batch, + block_size, + batch_num_tokens.data_ptr(), + sorted_ids.data_ptr(), + batch_ids.data_ptr(), + num_tokens_post_pad.data_ptr()); } -void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] - torch::Tensor& output) // [num_tokens, hidden_size] +void moe_sum( + torch::Tensor& input, // [num_tokens, topk, hidden_size] + torch::Tensor& output) // [num_tokens, hidden_size] { const int hidden_size = input.size(-1); const auto num_tokens = output.numel() / hidden_size; @@ -608,25 +735,22 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] switch (topk) { case 2: DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel<<>>( - output.data_ptr(), input.data_ptr(), - hidden_size); + moe::moe_sum_kernel + <<>>(output.data_ptr(), input.data_ptr(), hidden_size); }); break; case 3: DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel<<>>( - output.data_ptr(), input.data_ptr(), - hidden_size); + moe::moe_sum_kernel + <<>>(output.data_ptr(), input.data_ptr(), hidden_size); }); 
break; case 4: DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel<<>>( - output.data_ptr(), input.data_ptr(), - hidden_size); + moe::moe_sum_kernel + <<>>(output.data_ptr(), input.data_ptr(), hidden_size); }); break; @@ -637,33 +761,35 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] } void moe_lora_align_block_size( - torch::Tensor topk_ids, torch::Tensor token_lora_mapping, - int64_t num_experts, int64_t block_size, int64_t max_loras, - int64_t max_num_tokens_padded, int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, - torch::Tensor lora_ids, std::optional maybe_expert_map) { + torch::Tensor topk_ids, + torch::Tensor token_lora_mapping, + int64_t num_experts, + int64_t block_size, + int64_t max_loras, + int64_t max_num_tokens_padded, + int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, + torch::Tensor adapter_enabled, + torch::Tensor lora_ids, + std::optional maybe_expert_map) { const int topk_num = topk_ids.size(1); TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); int device_max_shared_mem; auto dev = topk_ids.get_device(); - cudaDeviceGetAttribute(&device_max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + cudaDeviceGetAttribute(&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int64_t padded_num_experts = - ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; // BlockScan uses 1024 threads and assigns one thread per expert. 
- TORCH_CHECK(padded_num_experts < 1024, - "padded_num_experts must be less than 1024"); + TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - torch::Tensor token_mask = - torch::empty({max_loras * topk_ids.size(0)}, options_int); + auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor token_mask = torch::empty({max_loras * topk_ids.size(0)}, options_int); bool has_expert_map = maybe_expert_map.has_value(); torch::Tensor expert_map; if (has_expert_map) { @@ -672,93 +798,106 @@ void moe_lora_align_block_size( expert_map = torch::empty({0}, options_int); } - DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { - bool small_batch_expert_mode = - (topk_ids.numel() < 1024) && (num_experts <= 64); - - if (small_batch_expert_mode) { - const int32_t num_thread = max((int32_t)num_experts, 128); - const int32_t shared_mem = - (num_thread + 1) * num_experts * sizeof(int32_t) + - (num_experts + 1) * sizeof(int32_t); - if (shared_mem > device_max_shared_mem) { - TORCH_CHECK(false, "Shared memory usage exceeds device limit."); - } - - // threadIdx.x >= fill_threads: counting experts and aligning - // threadIdx.x < fill_threads: filling sorted_token_ids - constexpr int32_t fill_threads = 256; - - dim3 blockDim(num_thread + fill_threads); - auto kernel = - moe::moe_lora_align_block_size_small_batch_expert_kernel< - scalar_t, fill_threads>; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); - kernel<<>>( - topk_ids.data_ptr(), - token_lora_mapping.data_ptr(), block_size, - expert_map.data_ptr(), num_experts, max_loras, - topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), lora_ids.data_ptr(), - 
token_mask.data_ptr(), has_expert_map); - } else { - int num_thread = 1024; - dim3 blockDim(num_thread); - size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); - - size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); - - // cumsum buffer - torch::Tensor cumsum = - torch::zeros({max_loras * (num_experts + 1)}, options_int); - - auto align_kernel = - moe::moe_lora_align_block_size_kernel; - - // launch two threadblocks for each lora - // blockIdx.x % 2 == 0: counting experts and aligning - // blockIdx.x % 2 == 1: filling sorted_token_ids - align_kernel<<>>( - topk_ids.data_ptr(), - token_lora_mapping.data_ptr(), block_size, - expert_map.data_ptr(), num_experts, max_loras, - topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), cumsum.data_ptr(), - WARP_SIZE, padded_num_experts, lora_ids.data_ptr(), - token_mask.data_ptr(), has_expert_map); - - const int block_threads = std::min(256, (int)num_thread); - const int num_blocks = - (topk_ids.numel() + block_threads - 1) / block_threads; - - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - - dim3 gridDims(max_loras, actual_blocks); - auto sort_kernel = - moe::lora_count_and_sort_expert_tokens_kernel; - - sort_kernel<<>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), cumsum.data_ptr(), - expert_map.data_ptr(), topk_ids.numel(), num_experts, - max_num_tokens_padded, topk_num, token_mask.data_ptr(), - lora_ids.data_ptr(), has_expert_map); - } - }); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { + bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t num_thread = max((int32_t)num_experts, 128); + const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + (num_experts + 1) * sizeof(int32_t); 
+ if (shared_mem > device_max_shared_mem) { + TORCH_CHECK(false, "Shared memory usage exceeds device limit."); + } + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + + dim3 blockDim(num_thread + fill_threads); + auto kernel = moe::moe_lora_align_block_size_small_batch_expert_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); + kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), + block_size, + expert_map.data_ptr(), + num_experts, + max_loras, + topk_ids.numel(), + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), + lora_ids.data_ptr(), + token_mask.data_ptr(), + has_expert_map); + } else { + int num_thread = 1024; + dim3 blockDim(num_thread); + size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); + + size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); + + // cumsum buffer + torch::Tensor cumsum = torch::zeros({max_loras * (num_experts + 1)}, options_int); + + auto align_kernel = moe::moe_lora_align_block_size_kernel; + + // launch two threadblocks for each lora + // blockIdx.x % 2 == 0: counting experts and aligning + // blockIdx.x % 2 == 1: filling sorted_token_ids + align_kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), + block_size, + expert_map.data_ptr(), + num_experts, + max_loras, + topk_ids.numel(), + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), + cumsum.data_ptr(), + WARP_SIZE, + padded_num_experts, + lora_ids.data_ptr(), + token_mask.data_ptr(), + has_expert_map); + + const int block_threads = std::min(256, (int)num_thread); + const int num_blocks = (topk_ids.numel() + block_threads - 1) / 
block_threads; + + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + dim3 gridDims(max_loras, actual_blocks); + auto sort_kernel = moe::lora_count_and_sort_expert_tokens_kernel; + + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum.data_ptr(), + expert_map.data_ptr(), + topk_ids.numel(), + num_experts, + max_num_tokens_padded, + topk_num, + token_mask.data_ptr(), + lora_ids.data_ptr(), + has_expert_map); + } + }); } // TODO: Jonahbernard: remove this later #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("moe_lora_align_block_size", &moe_lora_align_block_size, - "MoE LoRA Align Block Size"); + m.def("moe_lora_align_block_size", &moe_lora_align_block_size, "MoE LoRA Align Block Size"); } From 66885cdbe22e4dbae266e87d9573387c2e48f716 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Wed, 4 Feb 2026 16:59:11 -0500 Subject: [PATCH 087/150] fix test --- test/manual/lora/test_lora_moe_runner.py | 250 +++++++++++++++-------- 1 file changed, 168 insertions(+), 82 deletions(-) diff --git a/test/manual/lora/test_lora_moe_runner.py b/test/manual/lora/test_lora_moe_runner.py index b240d2308fa6..0942ca029c79 100644 --- a/test/manual/lora/test_lora_moe_runner.py +++ b/test/manual/lora/test_lora_moe_runner.py @@ -50,7 +50,9 @@ def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): return token_lora_mapping -def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int, dtype=torch.float32): +def assign_experts_to_tokens( + num_tokens: int, num_experts: int, top_k_num: int, dtype=torch.float32 +): assert top_k_num <= num_experts, "top_k_num must be <= num_experts" expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) @@ -64,37 +66,87 @@ def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int, return expert_indices, expert_weights -def sample_data(num_tokens: int, num_sequences: int, max_loras: int, num_experts: 
int, top_k_num: int, dtype=torch.float32): - topk_ids, topk_weights = assign_experts_to_tokens(num_tokens, num_experts, top_k_num, dtype) +def sample_data( + num_tokens: int, + num_sequences: int, + max_loras: int, + num_experts: int, + top_k_num: int, + dtype=torch.float32, +): + topk_ids, topk_weights = assign_experts_to_tokens( + num_tokens, num_experts, top_k_num, dtype + ) token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) return topk_ids, topk_weights, token_lora_mapping -def create_lora_info(token_lora_mapping, topk_ids, max_loras, num_experts, max_lora_rank, hidden_dim, intermediate_dim, gate_up_dim, dtype, device): - gate_up_lora_a_weights = torch.randn((max_loras, num_experts, max_lora_rank, hidden_dim), dtype=dtype, device=device) - gate_up_lora_b_weights = torch.randn((max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device) - down_lora_a_weights = torch.randn((max_loras, num_experts, max_lora_rank, intermediate_dim), dtype=dtype, device=device) - down_lora_b_weights = torch.randn((max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device) +def create_lora_info( + token_lora_mapping, + topk_ids, + max_loras, + num_experts, + max_lora_rank, + hidden_dim, + intermediate_dim, + gate_up_dim, + dtype, + device, +): + # ------------------------------------------------------------------------- + # 1. Deterministic LoRA A Initialization + # ------------------------------------------------------------------------- + # We fill A with (1 / input_dim). + # If input is all 1s, A @ x will result in a vector of all 1s. 
+ + val_gate_up_a = 1.0 / hidden_dim + gate_up_lora_a_weights = torch.full( + (max_loras, num_experts, max_lora_rank, hidden_dim), + val_gate_up_a, + dtype=dtype, + device=device, + ) + + val_down_a = 1.0 / intermediate_dim + down_lora_a_weights = torch.full( + (max_loras, num_experts, max_lora_rank, intermediate_dim), + val_down_a, + dtype=dtype, + device=device, + ) + + # ------------------------------------------------------------------------- + # 2. Deterministic LoRA B Initialization (Expert-Specific) + # ------------------------------------------------------------------------- + # We want the output to be safe but noticeable. + # Let's target a base value of 0.1 per expert index. + # Formula: fill_value = (target / rank) * (expert_id + 1) - num_tokens = token_lora_mapping.shape[0] - dispatched_tokens = [] - dispatched_experts = [] - dispatched_loras = [] + base_target = 0.01 # Small enough to not explode SiLU, big enough to see - for token_idx in range(num_tokens): - lora_id = token_lora_mapping[token_idx] - for k in range(topk_ids.shape[1]): - expert_id = topk_ids[token_idx, k] - dispatched_tokens.append(token_idx) - dispatched_experts.append(expert_id) - dispatched_loras.append(lora_id) + gate_up_lora_b_weights = torch.zeros( + (max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device + ) + down_lora_b_weights = torch.zeros( + (max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device + ) + + for i in range(num_experts): + # Make every expert add a slightly different value so we can check routing + # Expert 0 adds ~0.01, Expert 10 adds ~0.11 + expert_multiplier = i + 1 + + fill_val = (base_target * expert_multiplier) / max_lora_rank - token_ids = torch.tensor(dispatched_tokens, dtype=torch.int32, device=device) - expert_ids = torch.tensor(dispatched_experts, dtype=torch.int32, device=device) - lora_ids = torch.tensor(dispatched_loras, dtype=torch.int32, device=device) + gate_up_lora_b_weights[:, i, :, :] = 
fill_val + down_lora_b_weights[:, i, :, :] = fill_val - lora_ranks = torch.full((max_loras,), max_lora_rank, dtype=torch.int32, device=device) - lora_scalings = torch.ones(max_loras, dtype=dtype, device=device) + # ------------------------------------------------------------------------- + # 3. Setup Metadata + # ------------------------------------------------------------------------- + lora_ranks = torch.full( + (max_loras,), max_lora_rank, dtype=torch.int32, device=device + ) adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=device) return LoRAInfo( @@ -102,30 +154,38 @@ def create_lora_info(token_lora_mapping, topk_ids, max_loras, num_experts, max_l gate_up_lora_b_weights=gate_up_lora_b_weights, down_lora_a_weights=down_lora_a_weights, down_lora_b_weights=down_lora_b_weights, - token_ids=token_ids, - expert_ids=expert_ids, - lora_ids=lora_ids, + token_lora_mapping=token_lora_mapping, lora_ranks=lora_ranks, - lora_scalings=lora_scalings, adapter_enabled=adapter_enabled, max_lora_rank=max_lora_rank, num_experts=num_experts, ) -def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info): +def torch_naive_moe_with_lora( + hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info +): num_tokens, hidden_dim = hidden_states.shape top_k = topk_ids.shape[1] num_experts = w13.shape[0] intermediate_dim = w2.shape[2] - hidden_expanded = hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden_dim) - gate_up_out = torch.zeros(num_tokens * top_k, w13.shape[1], dtype=hidden_states.dtype, device=hidden_states.device) + hidden_expanded = ( + hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden_dim) + ) + + gate_up_out = torch.zeros( + num_tokens * top_k, + w13.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) for expert_id in range(num_experts): mask = (topk_ids == expert_id).flatten() if mask.any(): - gate_up_out[mask] = hidden_expanded[mask] @ w13[expert_id].T 
+ expert_result = hidden_expanded[mask] @ w13[expert_id].T + gate_up_out[mask] = expert_result if b13 is not None: gate_up_out[mask] += b13[expert_id] @@ -135,25 +195,38 @@ def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, top for i in range(num_tokens): for k in range(top_k): expert_id = topk_ids[i, k] - lora_id = lora_info.lora_ids[i * top_k + k] + lora_id = lora_info.token_lora_mapping[i] lora_a = lora_info.gate_up_lora_a_weights[lora_id, expert_id] lora_b = lora_info.gate_up_lora_b_weights[lora_id, expert_id] - lora_delta = lora_info.lora_scalings[lora_id] * (lora_b @ (lora_a @ hidden_states[i])) + lora_a_result = lora_a @ hidden_states[i] + lora_b_result = lora_b @ lora_a_result + # Using scaling factor of 1.0 since lora_scalings was removed + lora_delta = 1.0 * lora_b_result gate_up_out[i, k] += lora_delta gate_up_dim = gate_up_out.shape[-1] gate_dim = gate_up_dim // 2 gate = gate_up_out[..., :gate_dim] up = gate_up_out[..., gate_dim:] - intermediate_out = torch.nn.functional.silu(gate) * up - down_out = torch.zeros(num_tokens, top_k, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device) + silu_gate = torch.nn.functional.silu(gate) + + intermediate_out = silu_gate * up + + down_out = torch.zeros( + num_tokens, + top_k, + hidden_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) for expert_id in range(num_experts): - mask = (topk_ids == expert_id) + mask = topk_ids == expert_id if mask.any(): masked_intermediate = intermediate_out[mask] - down_out[mask] = masked_intermediate @ w2[expert_id].T + expert_down_result = masked_intermediate @ w2[expert_id].T + down_out[mask] = expert_down_result if b2 is not None: down_out[mask] += b2[expert_id] @@ -161,59 +234,59 @@ def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, top for i in range(num_tokens): for k in range(top_k): expert_id = topk_ids[i, k] - lora_id = lora_info.lora_ids[i * top_k + k] + lora_id = 
lora_info.token_lora_mapping[i] lora_a = lora_info.down_lora_a_weights[lora_id, expert_id] lora_b = lora_info.down_lora_b_weights[lora_id, expert_id] - lora_delta = lora_info.lora_scalings[lora_id] * (lora_b @ (lora_a @ intermediate_out[i, k])) + lora_a_result = lora_a @ intermediate_out[i, k] + lora_b_result = lora_b @ lora_a_result + # Using scaling factor of 1.0 since lora_scalings was removed + lora_delta = 1.0 * lora_b_result down_out[i, k] += lora_delta weighted_out = down_out * topk_weights.unsqueeze(-1) + final_out = weighted_out.sum(dim=1) return final_out -DTYPES = [torch.float16, torch.bfloat16] -DEVICES = ["cuda:0"] -SEED = [42] - - -@pytest.mark.parametrize("num_tokens", [32]) -@pytest.mark.parametrize("top_k_num", [2, 4]) -@pytest.mark.parametrize("num_experts", [8]) -@pytest.mark.parametrize("max_loras", [2, 4]) -@pytest.mark.parametrize("hidden_dim", [512]) -@pytest.mark.parametrize("intermediate_dim", [1024]) -@pytest.mark.parametrize("max_lora_rank", [16, 32]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("seed", SEED) -def test_lora_moe_runner( - num_tokens, - top_k_num, - num_experts, - max_loras, - hidden_dim, - intermediate_dim, - max_lora_rank, - dtype, - device, - seed, +@pytest.mark.parametrize("num_tokens", [32, 64]) +@pytest.mark.parametrize("top_k_num", [1, 2]) +@pytest.mark.parametrize("num_experts", [8, 20]) +@pytest.mark.parametrize("max_lora_rank", [8, 16]) +def test_lora_moe_runner_multi_expert( + num_tokens, top_k_num, num_experts, max_lora_rank ): + # Fixed parameters + max_loras = 2 + hidden_dim = 512 + intermediate_dim = 1024 + + dtype = torch.float32 + device = "cuda:0" + seed = 42 + torch.set_default_device(device) set_random_seed(seed) - num_sequences = 4 - topk_ids, topk_weights, token_lora_mapping = sample_data( - num_tokens, num_sequences, max_loras, num_experts, top_k_num, dtype + # Distribute tokens across all experts (Random Routing) + topk_ids, 
topk_weights = assign_experts_to_tokens( + num_tokens, num_experts, top_k_num, dtype ) + # Assign LoRAs randomly + num_sequences = 4 + token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) + gate_up_dim = intermediate_dim * 2 - w13 = torch.randn(num_experts, gate_up_dim, hidden_dim, dtype=dtype) - w2 = torch.randn(num_experts, hidden_dim, intermediate_dim, dtype=dtype) - b13 = torch.randn(num_experts, gate_up_dim, dtype=dtype) - b2 = torch.randn(num_experts, hidden_dim, dtype=dtype) + # Initialize ALL experts with non-zero random weights + w13 = torch.randn(num_experts, gate_up_dim, hidden_dim, dtype=dtype) * 0.01 + w2 = torch.randn(num_experts, hidden_dim, intermediate_dim, dtype=dtype) * 0.01 + b13 = torch.randn(num_experts, gate_up_dim, dtype=dtype) * 0.01 + b2 = torch.randn(num_experts, hidden_dim, dtype=dtype) * 0.01 + + # Set input to random values hidden_states = torch.randn(num_tokens, hidden_dim, dtype=dtype) lora_info = create_lora_info( @@ -229,10 +302,16 @@ def test_lora_moe_runner( device=device, ) + # Sort tokens by expert ID for the runner + topk_ids_flat = topk_ids.flatten() + sorted_indices = torch.argsort(topk_ids_flat) + sorted_token_ids = sorted_indices // top_k_num + expert_ids = topk_ids_flat[sorted_indices] + num_dispatched = num_tokens * top_k_num - sorted_token_ids = torch.arange(num_dispatched, dtype=torch.int32, device=device) - expert_ids = topk_ids.flatten().to(dtype=torch.int32, device=device) - num_tokens_post_padded = torch.tensor([num_dispatched], dtype=torch.int32, device=device) + num_tokens_post_padded = torch.tensor( + [num_dispatched], dtype=torch.int32, device=device + ) runner_input = TritonRunnerInput( hidden_states=hidden_states, @@ -262,8 +341,8 @@ def test_lora_moe_runner( num_local_experts=num_experts, ) - # Create StandardTopKOutput for DispatchOutput - router_logits = torch.randn(num_tokens, num_experts, dtype=dtype, device=device) # Dummy logits + # Create StandardTopKOutput + router_logits 
= torch.randn(num_tokens, num_experts, dtype=dtype, device=device) topk_output = StandardTopKOutput( topk_weights=topk_weights, topk_ids=topk_ids, @@ -277,12 +356,13 @@ def test_lora_moe_runner( topk_output=topk_output, ) - # Test the full MoeRunner flow with LoRA enabled - # Mock global server args to avoid dependency on server initialization class MockServerArgs: enable_deterministic_inference = False - with patch('sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.get_global_server_args', return_value=MockServerArgs()): + with patch( + "sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.get_global_server_args", + return_value=MockServerArgs(), + ): runner = MoeRunner(MoeRunnerBackend.TRITON, config, lora_enabled=True) combine_input = runner.run(dispatch_output, quant_info, lora_info) lora_output = combine_input @@ -291,4 +371,10 @@ class MockServerArgs: hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info ) - torch.testing.assert_close(lora_output.hidden_states, torch_output, atol=1e-1, rtol=1e-1) + print(f"lora_output.hidden_states mean: {lora_output.hidden_states.mean()}") + print(f"torch_output mean: {torch_output.mean()}") + + # Assert close + torch.testing.assert_close( + lora_output.hidden_states, torch_output, atol=1e-2, rtol=1e-2 + ) From f5cf61558c710177093335c1febbffd65cdf9412 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Thu, 5 Feb 2026 11:02:35 -0500 Subject: [PATCH 088/150] remove csgmv support --- python/sglang/srt/lora/layers.py | 6 ++++++ python/sglang/srt/lora/lora_manager.py | 17 ++--------------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 2de5f8b670bc..4821bb75a259 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -604,6 +604,12 @@ def __init__( self.down_lora_b_weights = None # initialize triton_lora moe runner for batches with lora enabled + if lora_backend.name != 
"triton": + raise ValueError( + "FusedMoEWithLoRA only supports 'triton' backend. " + "Please set --lora-backend triton when using LoRA on MoE models." + ) + from sglang.srt.layers.moe.moe_runner.runner import MoeRunner from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo self._lora_runner = MoeRunner( diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 246e81b55a01..105ca618ab16 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -291,7 +291,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): num_tokens = forward_batch.input_ids.shape[0] # Tokens in current forward pass # Create tensor and fill with adapter indices from segments - token_lora_indices_reordered = torch.empty( + token_lora_indices = torch.empty( num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device ) seg_indptr = batch_info.seg_indptr # [num_segments + 1] @@ -299,20 +299,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): start_token = seg_indptr[seg_idx] end_token = seg_indptr[seg_idx + 1] lora_adapter = batch_info.weight_indices[seg_idx] - token_lora_indices_reordered[start_token:end_token] = lora_adapter - - if batch_info.permutation is None: - # No reordering (e.g., triton backend): segments are in original order - token_lora_indices = token_lora_indices_reordered - else: - # Tokens are reordered (chunked backend): need to convert back to original order - inverse_permutation = torch.empty_like(batch_info.permutation) - inverse_permutation[batch_info.permutation] = torch.arange( - num_tokens, - dtype=batch_info.permutation.dtype, - device=batch_info.permutation.device, - ) - token_lora_indices = token_lora_indices_reordered[inverse_permutation] + token_lora_indices[start_token:end_token] = lora_adapter forward_batch.token_lora_indices = token_lora_indices From dccc35913b554933e0f507086edbbf00a5c3a91d Mon Sep 17 00:00:00 2001 From: Jonah Bernard 
Date: Fri, 6 Feb 2026 11:01:58 -0500 Subject: [PATCH 089/150] fixes --- python/sglang/srt/lora/layers.py | 34 +- python/sglang/srt/lora/lora_manager.py | 26 +- .../lora/triton_ops/fused_moe_lora_kernel.py | 33 +- .../csrc/moe/moe_lora_align_sum_kernel.cu | 860 +++++++++--------- 4 files changed, 455 insertions(+), 498 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 4821bb75a259..16eb2c97976e 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -595,6 +595,14 @@ def __init__( base_layer: FusedMoE, lora_backend: BaseLoRABackend, ): + + # initialize triton_lora moe runner for batches with lora enabled + if lora_backend.name != "triton": + raise ValueError( + "FusedMoEWithLoRA only supports 'triton' backend. " + "Please set --lora-backend triton when using LoRA on MoE models." + ) + # initializes FusedMoE with its own moe_runner for base path super().__init__(base_layer, lora_backend) # LoRA tensors will be set by LoRAManager @@ -604,12 +612,6 @@ def __init__( self.down_lora_b_weights = None # initialize triton_lora moe runner for batches with lora enabled - if lora_backend.name != "triton": - raise ValueError( - "FusedMoEWithLoRA only supports 'triton' backend. " - "Please set --lora-backend triton when using LoRA on MoE models." - ) - from sglang.srt.layers.moe.moe_runner.runner import MoeRunner from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo self._lora_runner = MoeRunner( @@ -641,8 +643,7 @@ def set_lora_info( self.down_lora_b_weights = down_lora_b_weights def _get_lora_info( - self, - topk_output: TopKOutput, + self ): """ Build LoRAInfo for the current batch. 
@@ -656,28 +657,21 @@ def _get_lora_info( lora_ranks = batch_info.lora_ranks # [num_loras] # Compute max LoRA rank from current batch ranks - if hasattr(batch_info, "max_lora_rank") and batch_info.max_lora_rank is not None: - max_lora_rank = batch_info.max_lora_rank - else: - max_lora_rank = int(torch.max(lora_ranks)) - - # Use precomputed per-token LoRA indices from forward batch (int32 for kernel use) - lora_indices = self.lora_backend.forward_batch.token_lora_indices.to( - torch.int32 - ) + max_lora_rank = self.lora_backend.max_lora_rank # Create adapter_enabled tensor for the current batch # Only enable LoRA adapters that are actually used in this batch # TODO: Jonahbernard: check that this doesn't slow down inference for this batch adapter_enabled = torch.zeros(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) - adapter_enabled.index_fill_(0, lora_indices.long(), 1) + adapter_enabled.index_fill_(0, batch_info.weight_indices.long(), 1) return LoRAInfo( gate_up_lora_a_weights=self.gate_up_lora_a_weights, gate_up_lora_b_weights=self.gate_up_lora_b_weights, down_lora_a_weights=self.down_lora_a_weights, down_lora_b_weights=self.down_lora_b_weights, - token_lora_mapping=lora_indices, + seg_ind = batch_info.seg_indptr, + req_to_lora = batch_info.weight_indices, lora_ranks=lora_ranks, adapter_enabled=adapter_enabled, max_lora_rank=max_lora_rank, @@ -694,7 +688,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs """ # Build LoRA info for this batch - lora_info = self._get_lora_info(topk_output) + lora_info = self._get_lora_info() # run lora moe_runner return self._forward_with_lora(hidden_states, topk_output, lora_info, **kwargs) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 105ca618ab16..75d3c24ce2c9 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -283,28 +283,6 @@ def prepare_lora_batch(self, forward_batch: 
ForwardBatch): use_cuda_graph=use_cuda_graph, ) - # Attach max_lora_rank to batch_info for MoE usage - self.lora_backend.batch_info.max_lora_rank = max(lora_ranks) if lora_ranks else 0 - - # Populate per-token LoRA indices from segment information - batch_info = self.lora_backend.batch_info - num_tokens = forward_batch.input_ids.shape[0] # Tokens in current forward pass - - # Create tensor and fill with adapter indices from segments - token_lora_indices = torch.empty( - num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device - ) - seg_indptr = batch_info.seg_indptr # [num_segments + 1] - for seg_idx in range(batch_info.num_segments): - start_token = seg_indptr[seg_idx] - end_token = seg_indptr[seg_idx + 1] - lora_adapter = batch_info.weight_indices[seg_idx] - token_lora_indices[start_token:end_token] = lora_adapter - - forward_batch.token_lora_indices = token_lora_indices - - # Store forward_batch reference in backend for MoE layer access - self.lora_backend.forward_batch = forward_batch def update_lora_info(self): """ @@ -399,6 +377,10 @@ def init_state( max_lora_rank=max_lora_rank, target_modules=target_modules, ) + + # Inject max_lora_rank into backend + self.lora_backend.max_lora_rank = self.max_lora_rank + self.init_lora_modules() self.init_memory_pool() self.update_lora_info() diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index fb27dbf5caa7..d11c1c7f5205 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -94,13 +94,14 @@ def _fused_moe_lora_kernel( launch_pdl: tl.constexpr, IS_PRIMARY: tl.constexpr, ): - # TODO (Jonahcb): investigate why GDC is not working - USE_GDC = False pid = tl.program_id(axis=0) slice_id = tl.program_id(axis=1) lora_idx = tl.program_id(axis=2) lora_id = tl.load(lora_ids + lora_idx) + USE_GDC = False # TODO (Jonahcb): remove this + + if lora_id == 
-1: # Early exit for the no-lora case. return @@ -132,6 +133,8 @@ def _fused_moe_lora_kernel( expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1) if expert_id == -1: return + + # get a_ptr,b_ptr,c_ptr cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) @@ -139,6 +142,8 @@ def _fused_moe_lora_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + # ================================================================= secure + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) token_ind = stride_tl * lora_id + offs_token_id @@ -147,6 +152,9 @@ def _fused_moe_lora_kernel( ) token_mask = offs_token < num_valid_tokens + # ================================================================= secure + + # get a_ptrs,b_ptrs a_ptrs = cur_a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak @@ -160,18 +168,30 @@ def _fused_moe_lora_kernel( + offs_bn[None, :] * stride_bn ) + + if USE_GDC and IS_PRIMARY: # GDC launch dependents hints the runtime system to launch dependent kernels. tl.extra.cuda.gdc_launch_dependents() + # ================================================================= secure + # accumulator accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # ================================================================= secure + + + + # GDC wait waits for ALL programs in the prior kernel to complete # before continuing. 
if USE_GDC and not IS_PRIMARY: tl.extra.cuda.gdc_wait() + + for k in range(0, grid_k): k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) # pre-fetch lora weight @@ -235,7 +255,10 @@ def _fused_moe_lora_shrink( mul_routed_weight: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] - use_gdc = is_arch_support_pdl() + + # TODO (Jonahcb): investigate why relying on is_arch_support_pdl() is causing crash inside kernel + #use_gdc = is_arch_support_pdl() + use_gdc = False shrink_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, @@ -329,6 +352,7 @@ def _fused_moe_lora_expand( mul_routed_weight: bool = False, offset: int = 0, ) -> None: + b_ptr = _get_ptr(lora_b_stacked, device) K = max_lora_rank N = w1_output_dim_size @@ -439,6 +463,8 @@ def _fused_moe_lora( == qcurr_hidden_states.dim() == 2 ) + if sorted_token_ids.shape[0] != expert_ids.shape[0] or sorted_token_ids.shape[0] != num_tokens_post_padded.shape[0]: + x = 1 assert ( sorted_token_ids.shape[0] == expert_ids.shape[0] @@ -469,6 +495,7 @@ def _fused_moe_lora( device=device, ) + _fused_moe_lora_shrink( a_intermediate_cache1, qcurr_hidden_states, diff --git a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu index 06f16cbf9253..e350f9108448 100644 --- a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu @@ -1,14 +1,66 @@ // Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu -// TODO (Jonahcb): merge with moe_align_kernel.cu +#include #include #include -#include #include + #include -#include "utils.h" +// ================================================================ +// STANDALONE UTILS REPLACEMENT +// Insert this after your standard #include statements +// ================================================================ + +#ifndef WARP_SIZE +#define WARP_SIZE 32 +#endif + +// Used in batched_moe_align_block_size +template +__host__ __device__ 
inline T round_to_next_multiple_of(T x, T y) { + return ((x + y - 1) / y) * y; +} + +// Minimal Dispatch Macros to avoid compiling full utils +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Int: { \ + using scalar_t = int32_t; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::Long: { \ + using scalar_t = int64_t; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + TORCH_CHECK(false, #NAME " not implemented for ", toString(TYPE)); \ + } + +#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + TORCH_CHECK(false, #NAME " not implemented for ", toString(TYPE)); \ + } +// ================================================================ #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) @@ -19,21 +71,21 @@ namespace batched_moe_align_block_size { static constexpr int32_t num_threads = 1024; static constexpr int32_t num_blocks = 1; __global__ void batched_moe_align_block_size_kernel( - int32_t const num_batches, - int32_t const max_tokens_per_batch, - int32_t const block_size, - int32_t const* __restrict__ batch_num_tokens, - int32_t* __restrict__ sorted_ids, - int32_t* __restrict__ block_ids, + int32_t const num_batches, int32_t const max_tokens_per_batch, + int32_t const block_size, int32_t const* __restrict__ batch_num_tokens, + int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids, int32_t* __restrict__ num_tokens_post_pad) { // TODO(varun): This is a naive implementation. Could be optimized. 
size_t const batch_id = threadIdx.x; size_t const stride = blockDim.x * gridDim.x; - int32_t const num_blocks_per_batch = CEILDIV(max_tokens_per_batch, block_size); - int32_t const sorted_ids_size = num_blocks_per_batch * num_batches * block_size; + int32_t const num_blocks_per_batch = + CEILDIV(max_tokens_per_batch, block_size); + int32_t const sorted_ids_size = + num_blocks_per_batch * num_batches * block_size; int32_t const block_ids_size = sorted_ids_size / block_size; - int32_t const SENTINEL = num_batches * max_tokens_per_batch; // To denote invalid entries. + int32_t const SENTINEL = + num_batches * max_tokens_per_batch; // To denote invalid entries. // Intialize sorted_ids for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) { sorted_ids[i] = SENTINEL; @@ -47,7 +99,8 @@ __global__ void batched_moe_align_block_size_kernel( if (batch_id < num_batches) { b_num_tokens = batch_num_tokens[batch_id]; } - int32_t const ceil_b_num_tokens = CEILDIV(b_num_tokens, block_size) * block_size; + int32_t const ceil_b_num_tokens = + CEILDIV(b_num_tokens, block_size) * block_size; // Compute prefix sum over token counts per expert using BlockScan = cub::BlockScan; @@ -79,23 +132,13 @@ __global__ void batched_moe_align_block_size_kernel( template __device__ void _moe_align_block_size( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, - int32_t num_experts, - int32_t padded_num_experts, - int32_t experts_per_warp, - int32_t block_size, - size_t numel, - int32_t* __restrict__ cumsum, - int32_t max_num_tokens_padded, - int32_t max_num_m_blocks, - int32_t model_offset, - int32_t inactive_expert_id, - int32_t topk_num, - int32_t* token_mask, - bool has_expert_map) { + int32_t* __restrict__ expert_map, int32_t num_experts, + int32_t padded_num_experts, 
int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id, + int32_t topk_num, int32_t* token_mask, bool has_expert_map) { extern __shared__ int32_t shared_counts[]; // Compute input buffer offsets. Typically these will all be 0, except when @@ -108,7 +151,8 @@ __device__ void _moe_align_block_size( // This is safe since the current kernel does not use sorted_token_ids. if (blockIdx.x % 2) { // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { + for (size_t it = threadIdx.x; it < max_num_tokens_padded; + it += blockDim.x) { sorted_token_ids[sorted_token_ids_offset + it] = numel; } return; @@ -141,7 +185,8 @@ __device__ void _moe_align_block_size( int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; - atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], mask); + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], + mask); } __syncthreads(); @@ -172,13 +217,15 @@ __device__ void _moe_align_block_size( __syncthreads(); if (threadIdx.x < num_experts) { - for (int i = cumsum[cumsum_offset + threadIdx.x]; i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { + for (int i = cumsum[cumsum_offset + threadIdx.x]; + i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { expert_ids[expert_ids_offset + i / block_size] = threadIdx.x; } } // Fill remaining expert_ids with 0 - const size_t fill_start_idx = cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; + const size_t fill_start_idx = + cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { expert_ids[expert_ids_offset + i] = inactive_expert_id; } @@ -187,20 +234,12 
@@ __device__ void _moe_align_block_size( template __device__ void _moe_align_block_size_small_batch_expert( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, - int32_t num_experts, - int32_t block_size, - size_t numel, - int32_t max_num_tokens_padded, - int32_t max_num_m_blocks, - int32_t inactive_expert_id, - int32_t model_offset, - int32_t topk_num, - int32_t* token_mask, - bool has_expert_map) { + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, + size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks, + int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num, + int32_t* token_mask, bool has_expert_map) { // Compute input buffer offsets. Typically these will all be 0, except when // using Multi LoRA. int sorted_token_ids_offset = max_num_tokens_padded * model_offset; @@ -212,7 +251,8 @@ __device__ void _moe_align_block_size_small_batch_expert( // synchronization easier. 
if (threadIdx.x < fill_threads) { // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += fill_threads) { + for (size_t it = threadIdx.x; it < max_num_tokens_padded; + it += fill_threads) { sorted_token_ids[sorted_token_ids_offset + it] = numel; } // Three __syncthreads() corresponding to the other threads @@ -249,7 +289,8 @@ __device__ void _moe_align_block_size_small_batch_expert( if (tid < num_experts) { tokens_cnts[tid] = 0; for (int i = 1; i <= stride; ++i) { - tokens_cnts[i * num_experts + tid] += tokens_cnts[(i - 1) * num_experts + tid]; + tokens_cnts[i * num_experts + tid] += + tokens_cnts[(i - 1) * num_experts + tid]; } } @@ -258,9 +299,13 @@ __device__ void _moe_align_block_size_small_batch_expert( if (tid == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * block_size; + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * + block_size; } - total_tokens_post_pad[model_offset] = static_cast(cumsum[num_experts]); + total_tokens_post_pad[model_offset] = + static_cast(cumsum[num_experts]); } __syncthreads(); @@ -284,7 +329,8 @@ __device__ void _moe_align_block_size_small_batch_expert( // filter invalid expert if (expert_id == -1) continue; } - int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + int32_t rank_post_pad = + tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; if (token_mask == nullptr || token_mask[i / topk_num]) { sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; @@ -296,16 +342,10 @@ __device__ void _moe_align_block_size_small_batch_expert( template __device__ void _count_and_sort_expert_tokens( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, - size_t numel, - int32_t 
num_experts, - int32_t max_num_tokens_padded, - int32_t* __restrict__ token_mask, - int32_t model_offset, - int32_t topk_num, - bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask, + int32_t model_offset, int32_t topk_num, bool has_expert_map) { const size_t tid = blockIdx.y * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.y; @@ -322,8 +362,10 @@ __device__ void _count_and_sort_expert_tokens( } if (token_mask == nullptr || token_mask[i / topk_num]) { - int32_t rank_post_pad = atomicAdd(&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); - sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = i; + int32_t rank_post_pad = atomicAdd( + &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); + sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = + i; } } } @@ -331,63 +373,28 @@ __device__ void _count_and_sort_expert_tokens( template __global__ void moe_align_block_size_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, - int32_t num_experts, - int32_t padded_num_experts, - int32_t experts_per_warp, - int32_t block_size, - size_t numel, - int32_t* __restrict__ cumsum, - int32_t max_num_tokens_padded, - int32_t topk_num, - bool has_expert_map) { + int32_t* __restrict__ expert_map, int32_t num_experts, + int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, + int32_t topk_num, bool has_expert_map) { _moe_align_block_size( - topk_ids, - sorted_token_ids, - expert_ids, - 
total_tokens_post_pad, - expert_map, - num_experts, - padded_num_experts, - experts_per_warp, - block_size, - numel, - cumsum, - max_num_tokens_padded, - CEILDIV(max_num_tokens_padded, block_size), - 0, - 0, - topk_num, - nullptr, - has_expert_map); + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, padded_num_experts, experts_per_warp, block_size, numel, + cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size), + 0, 0, topk_num, nullptr, has_expert_map); } template __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, - size_t numel, - int32_t num_experts, - int32_t max_num_tokens_padded, - int32_t topk_num, - bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) { _count_and_sort_expert_tokens( - topk_ids, - sorted_token_ids, - cumsum_buffer, - expert_map, - numel, - num_experts, - max_num_tokens_padded, - nullptr, - 0, - topk_num, - has_expert_map); + topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, + max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map); } template @@ -409,56 +416,29 @@ __global__ void moe_sum_kernel( template __global__ void moe_align_block_size_small_batch_expert_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, - int32_t num_experts, - int32_t block_size, - size_t numel, - int32_t max_num_tokens_padded, - int32_t topk_num, + int32_t* __restrict__ expert_map, int32_t num_experts, 
int32_t block_size, + size_t numel, int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) { _moe_align_block_size_small_batch_expert( - topk_ids, - sorted_token_ids, - expert_ids, - total_tokens_post_pad, - expert_map, - num_experts, - block_size, - numel, - max_num_tokens_padded, - CEILDIV(max_num_tokens_padded, block_size), - 0, - 0, - topk_num, - nullptr, + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, block_size, numel, max_num_tokens_padded, + CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr, has_expert_map); } template __global__ void moe_lora_align_block_size_kernel( - scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ token_lora_mapping, - int64_t block_size, - int32_t* __restrict__ expert_map, - int num_experts, - int max_loras, - size_t numel, - int max_num_tokens_padded, - int max_num_m_blocks, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, - int32_t topk_num, - int32_t* total_tokens_post_pad, - int32_t* adapter_enabled, - int32_t* __restrict__ cumsum, - int32_t experts_per_warp, - int32_t padded_num_experts, - int32_t* lora_ids, - int32_t* __restrict__ token_mask, - bool has_expert_map) { + scalar_t* __restrict__ topk_ids, int32_t* __restrict__ seg_indptr, int32_t* __restrict__ req_to_lora, + int num_reqs, int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, + int max_loras, size_t numel, int max_num_tokens_padded, + int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t topk_num, + int32_t* total_tokens_post_pad, int32_t* adapter_enabled, + int32_t* __restrict__ cumsum, int32_t experts_per_warp, + int32_t padded_num_experts, int32_t* lora_ids, + int32_t* __restrict__ token_mask, bool has_expert_map) { int lora_idx = blockIdx.x / 2; int lora_id = lora_ids[lora_idx]; if (lora_id == -1 || adapter_enabled[lora_id] == 0) { @@ -467,50 +447,51 @@ __global__ void 
moe_lora_align_block_size_kernel( // Populate the token_mask based on the token-LoRA mapping int num_tokens = numel / topk_num; + int lora_offset = lora_id * num_tokens; + + // 1. Parallel Clear (Reset mask to 0) + // All threads help clear the mask for this adapter + for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { + token_mask[lora_offset + i] = 0; + } + + // Initialize output counter if (threadIdx.x == 0) { - total_tokens_post_pad[lora_id] = 0; + total_tokens_post_pad[lora_id] = 0; + } - for (int i = 0; i < num_tokens; i++) { - token_mask[(lora_id * num_tokens) + i] = (int)token_lora_mapping[i] == lora_id; - } + __syncthreads(); + + // 2. Segment-based Fill + // Iterate over requests. If a request matches this LoRA, fill its range. + for (int r = 0; r < num_reqs; ++r) { + if (req_to_lora[r] == lora_id) { + int start = seg_indptr[r]; + int end = seg_indptr[r+1]; + + // Parallel Fill: All threads help mark this segment as "1" + for (int i = start + threadIdx.x; i < end; i += blockDim.x) { + token_mask[lora_offset + i] = 1; + } + } } __syncthreads(); _moe_align_block_size( - topk_ids, - sorted_token_ids, - expert_ids, - total_tokens_post_pad, - expert_map, - num_experts, - padded_num_experts, - experts_per_warp, - block_size, - numel, - cumsum, - max_num_tokens_padded, - max_num_m_blocks, - lora_id, - -1, - topk_num, - &token_mask[(lora_id * num_tokens)], - has_expert_map); + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, padded_num_experts, experts_per_warp, block_size, numel, + cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num, + &token_mask[(lora_id * num_tokens)], has_expert_map); } template __global__ void lora_count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, - size_t numel, - int32_t num_experts, - int32_t max_num_tokens_padded, - int32_t 
topk_num, - int32_t* token_mask, - int32_t* lora_ids, - bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask, + int32_t* lora_ids, bool has_expert_map) { int lora_idx = blockIdx.x; int lora_id = lora_ids[lora_idx]; if (lora_id == -1) { @@ -520,38 +501,20 @@ __global__ void lora_count_and_sort_expert_tokens_kernel( int num_tokens = numel / topk_num; _count_and_sort_expert_tokens( - topk_ids, - sorted_token_ids, - cumsum_buffer, - expert_map, - numel, - num_experts, - max_num_tokens_padded, - &token_mask[(lora_id * num_tokens)], - lora_id, - topk_num, - has_expert_map); + topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, + max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id, + topk_num, has_expert_map); } template __global__ void moe_lora_align_block_size_small_batch_expert_kernel( - scalar_t* __restrict__ topk_ids, - int32_t* token_lora_mapping, - int64_t block_size, - int32_t* __restrict__ expert_map, - int num_experts, - int max_loras, - size_t numel, - int max_num_tokens_padded, - int max_num_m_blocks, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, - int topk_num, - int32_t* total_tokens_post_pad, - int32_t* adapter_enabled, - int32_t* lora_ids, - int32_t* token_mask, - bool has_expert_map) { + scalar_t* __restrict__ topk_ids, int32_t* __restrict__ seg_indptr, int32_t* __restrict__ req_to_lora, + int num_reqs, int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, + int max_loras, size_t numel, int max_num_tokens_padded, + int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int topk_num, + int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids, + int32_t* token_mask, bool has_expert_map) { int lora_idx = blockIdx.x; int 
lora_id = lora_ids[lora_idx]; if (lora_id == -1 || adapter_enabled[lora_id] == 0) { @@ -559,31 +522,41 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( } int num_tokens = numel / topk_num; + int lora_offset = lora_id * num_tokens; + + // 1. Parallel Clear (Reset mask to 0) + // All threads help clear the mask for this adapter + for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { + token_mask[lora_offset + i] = 0; + } + + // Initialize output counter if (threadIdx.x == 0) { - total_tokens_post_pad[lora_id] = 0; + total_tokens_post_pad[lora_id] = 0; + } - for (int i = 0; i < num_tokens; i++) { - token_mask[(lora_id * num_tokens) + i] = (int)token_lora_mapping[i] == lora_id; - } + __syncthreads(); + + // 2. Segment-based Fill + // Iterate over requests. If a request matches this LoRA, fill its range. + for (int r = 0; r < num_reqs; ++r) { + if (req_to_lora[r] == lora_id) { + int start = seg_indptr[r]; + int end = seg_indptr[r+1]; + + // Parallel Fill: All threads help mark this segment as "1" + for (int i = start + threadIdx.x; i < end; i += blockDim.x) { + token_mask[lora_offset + i] = 1; + } + } } __syncthreads(); _moe_align_block_size_small_batch_expert( - topk_ids, - sorted_token_ids, - expert_ids, - total_tokens_post_pad, - expert_map, - num_experts, - block_size, - numel, - max_num_tokens_padded, - max_num_m_blocks, - -1, - lora_id, - topk_num, - &token_mask[(lora_id * num_tokens)], + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks, + -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)], has_expert_map); } @@ -591,24 +564,24 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( // taken from // https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc -void moe_align_block_size( - torch::Tensor topk_ids, - int64_t num_experts, - int64_t block_size, - torch::Tensor sorted_token_ids, 
- torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad, - std::optional maybe_expert_map) { +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + std::optional maybe_expert_map) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; int experts_per_warp = WARP_SIZE; int threads = 1024; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; // BlockScan uses 1024 threads and assigns one thread per expert. - TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); - auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); bool has_expert_map = maybe_expert_map.has_value(); torch::Tensor expert_map; if (has_expert_map) { @@ -617,89 +590,86 @@ void moe_align_block_size( expert_map = torch::empty({0}, options_int); } - DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `cumsum` tensors - bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); - - if (small_batch_expert_mode) { - const int32_t threads = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); - - // threadIdx.x >= fill_threads: counting experts and aligning - // threadIdx.x < fill_threads: filling sorted_token_ids - constexpr int32_t fill_threads = 256; - auto small_batch_expert_kernel = moe::moe_align_block_size_small_batch_expert_kernel; - 
small_batch_expert_kernel<<<1, fill_threads + threads, shared_mem_size, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - expert_map.data_ptr(), - num_experts, - block_size, - topk_ids.numel(), - sorted_token_ids.size(0), - topk_ids.size(1), - has_expert_map); - } else { - torch::Tensor cumsum_buffer = torch::empty({num_experts + 1}, options_int); - auto align_kernel = moe::moe_align_block_size_kernel; - - size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); - size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); - - // launch two threadblocks - // blockIdx.x == 0: counting experts and aligning - // blockIdx.x == 1: filling sorted_token_ids - align_kernel<<<2, threads, shared_mem_size, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - expert_map.data_ptr(), - num_experts, - padded_num_experts, - experts_per_warp, - block_size, - topk_ids.numel(), - cumsum_buffer.data_ptr(), - sorted_token_ids.size(0), - topk_ids.size(1), - has_expert_map); - - const int block_threads = std::min(256, (int)threads); - const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - dim3 gridDims(1, actual_blocks); - - auto sort_kernel = moe::count_and_sort_expert_tokens_kernel; - sort_kernel<<>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), - expert_map.data_ptr(), - topk_ids.numel(), - num_experts, - sorted_token_ids.size(0), - topk_ids.size(1), - has_expert_map); - } - }); + DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `cumsum` tensors + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t 
threads = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_size = + ((threads + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + auto small_batch_expert_kernel = + moe::moe_align_block_size_small_batch_expert_kernel< + scalar_t, fill_threads>; + small_batch_expert_kernel<<<1, fill_threads + threads, + shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), num_experts, block_size, + topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1), + has_expert_map); + } else { + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); + auto align_kernel = moe::moe_align_block_size_kernel; + + size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); + size_t shared_mem_size = + num_warps * experts_per_warp * sizeof(int32_t); + + // launch two threadblocks + // blockIdx.x == 0: counting experts and aligning + // blockIdx.x == 1: filling sorted_token_ids + align_kernel<<<2, threads, shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), num_experts, padded_num_experts, + experts_per_warp, block_size, topk_ids.numel(), + cumsum_buffer.data_ptr(), sorted_token_ids.size(0), + topk_ids.size(1), has_expert_map); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + dim3 gridDims(1, actual_blocks); + + auto sort_kernel = + moe::count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + 
cumsum_buffer.data_ptr(), expert_map.data_ptr(), + topk_ids.numel(), num_experts, sorted_token_ids.size(0), + topk_ids.size(1), has_expert_map); + } + }); } -void batched_moe_align_block_size( - int64_t max_tokens_per_batch, - int64_t block_size, - torch::Tensor const& batch_num_tokens, - torch::Tensor sorted_ids, - torch::Tensor batch_ids, - torch::Tensor num_tokens_post_pad) { +void batched_moe_align_block_size(int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& batch_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor batch_ids, + torch::Tensor num_tokens_post_pad) { namespace batched_kernel = moe::batched_moe_align_block_size; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int32_t const B = batch_num_tokens.size(0); - int32_t const num_blocks_per_batch = round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; + int32_t const num_blocks_per_batch = + round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; int32_t const num_blocks = num_blocks_per_batch * B; int64_t const sorted_ids_size = num_blocks * block_size; @@ -708,20 +678,15 @@ void batched_moe_align_block_size( TORCH_CHECK(num_tokens_post_pad.size(0) == 1); TORCH_CHECK(B <= batched_kernel::num_threads); - batched_kernel:: - batched_moe_align_block_size_kernel<<>>( - B, - max_tokens_per_batch, - block_size, - batch_num_tokens.data_ptr(), - sorted_ids.data_ptr(), - batch_ids.data_ptr(), - num_tokens_post_pad.data_ptr()); + batched_kernel::batched_moe_align_block_size_kernel<<< + batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>( + B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr(), + sorted_ids.data_ptr(), batch_ids.data_ptr(), + num_tokens_post_pad.data_ptr()); } -void moe_sum( - torch::Tensor& input, // [num_tokens, topk, hidden_size] - torch::Tensor& output) // [num_tokens, hidden_size] +void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] + torch::Tensor& output) // 
[num_tokens, hidden_size] { const int hidden_size = input.size(-1); const auto num_tokens = output.numel() / hidden_size; @@ -735,22 +700,25 @@ void moe_sum( switch (topk) { case 2: DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel - <<>>(output.data_ptr(), input.data_ptr(), hidden_size); + moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); }); break; case 3: DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel - <<>>(output.data_ptr(), input.data_ptr(), hidden_size); + moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); }); break; case 4: DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel - <<>>(output.data_ptr(), input.data_ptr(), hidden_size); + moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); }); break; @@ -761,35 +729,33 @@ void moe_sum( } void moe_lora_align_block_size( - torch::Tensor topk_ids, - torch::Tensor token_lora_mapping, - int64_t num_experts, - int64_t block_size, - int64_t max_loras, - int64_t max_num_tokens_padded, - int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, - torch::Tensor adapter_enabled, - torch::Tensor lora_ids, - std::optional maybe_expert_map) { + torch::Tensor topk_ids, torch::Tensor seg_indptr, torch::Tensor req_to_lora, + int64_t num_experts, int64_t block_size, int64_t max_loras, + int64_t max_num_tokens_padded, int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, + torch::Tensor lora_ids, std::optional maybe_expert_map) { const int topk_num = topk_ids.size(1); TORCH_CHECK(block_size > 0, "block_size should be greater than 0. 
"); int device_max_shared_mem; auto dev = topk_ids.get_device(); - cudaDeviceGetAttribute(&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; // BlockScan uses 1024 threads and assigns one thread per expert. - TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); + TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); - auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - torch::Tensor token_mask = torch::empty({max_loras * topk_ids.size(0)}, options_int); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor token_mask = + torch::empty({max_loras * topk_ids.size(0)}, options_int); bool has_expert_map = maybe_expert_map.has_value(); torch::Tensor expert_map; if (has_expert_map) { @@ -797,107 +763,95 @@ void moe_lora_align_block_size( } else { expert_map = torch::empty({0}, options_int); } - - DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { - bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); - - if (small_batch_expert_mode) { - const int32_t num_thread = max((int32_t)num_experts, 128); - const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + (num_experts + 1) * sizeof(int32_t); - if (shared_mem > device_max_shared_mem) { - TORCH_CHECK(false, "Shared memory usage exceeds device limit."); - } - - // threadIdx.x >= fill_threads: counting experts and aligning - // threadIdx.x < fill_threads: filling sorted_token_ids - constexpr int32_t fill_threads = 256; - - dim3 
blockDim(num_thread + fill_threads); - auto kernel = moe::moe_lora_align_block_size_small_batch_expert_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); - kernel<<>>( - topk_ids.data_ptr(), - token_lora_mapping.data_ptr(), - block_size, - expert_map.data_ptr(), - num_experts, - max_loras, - topk_ids.numel(), - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), - topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), - lora_ids.data_ptr(), - token_mask.data_ptr(), - has_expert_map); - } else { - int num_thread = 1024; - dim3 blockDim(num_thread); - size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); - - size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); - - // cumsum buffer - torch::Tensor cumsum = torch::zeros({max_loras * (num_experts + 1)}, options_int); - - auto align_kernel = moe::moe_lora_align_block_size_kernel; - - // launch two threadblocks for each lora - // blockIdx.x % 2 == 0: counting experts and aligning - // blockIdx.x % 2 == 1: filling sorted_token_ids - align_kernel<<>>( - topk_ids.data_ptr(), - token_lora_mapping.data_ptr(), - block_size, - expert_map.data_ptr(), - num_experts, - max_loras, - topk_ids.numel(), - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), - topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), - cumsum.data_ptr(), - WARP_SIZE, - padded_num_experts, - lora_ids.data_ptr(), - token_mask.data_ptr(), - has_expert_map); - - const int block_threads = std::min(256, (int)num_thread); - const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; - - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - - dim3 gridDims(max_loras, actual_blocks); - auto sort_kernel = moe::lora_count_and_sort_expert_tokens_kernel; - - sort_kernel<<>>( - topk_ids.data_ptr(), - 
sorted_token_ids.data_ptr(), - cumsum.data_ptr(), - expert_map.data_ptr(), - topk_ids.numel(), - num_experts, - max_num_tokens_padded, - topk_num, - token_mask.data_ptr(), - lora_ids.data_ptr(), - has_expert_map); - } - }); + int num_reqs = seg_indptr.size(0) - 1; + + DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t num_thread = max((int32_t)num_experts, 128); + const int32_t shared_mem = + (num_thread + 1) * num_experts * sizeof(int32_t) + + (num_experts + 1) * sizeof(int32_t); + if (shared_mem > device_max_shared_mem) { + TORCH_CHECK(false, "Shared memory usage exceeds device limit."); + } + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + + dim3 blockDim(num_thread + fill_threads); + auto kernel = + moe::moe_lora_align_block_size_small_batch_expert_kernel< + scalar_t, fill_threads>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); + kernel<<>>( + topk_ids.data_ptr(), + seg_indptr.data_ptr(), req_to_lora.data_ptr(), num_reqs, block_size, + expert_map.data_ptr(), num_experts, max_loras, + topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), lora_ids.data_ptr(), + token_mask.data_ptr(), has_expert_map); + } else { + int num_thread = 1024; + dim3 blockDim(num_thread); + size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); + + size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); + + // cumsum buffer + torch::Tensor cumsum = + torch::zeros({max_loras * (num_experts + 1)}, options_int); + + auto align_kernel = + moe::moe_lora_align_block_size_kernel; + + // launch two threadblocks for each lora 
+ // blockIdx.x % 2 == 0: counting experts and aligning + // blockIdx.x % 2 == 1: filling sorted_token_ids + align_kernel<<>>( + topk_ids.data_ptr(), + seg_indptr.data_ptr(), req_to_lora.data_ptr(), num_reqs, block_size, + expert_map.data_ptr(), num_experts, max_loras, + topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), cumsum.data_ptr(), + WARP_SIZE, padded_num_experts, lora_ids.data_ptr(), + token_mask.data_ptr(), has_expert_map); + + const int block_threads = std::min(256, (int)num_thread); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + dim3 gridDims(max_loras, actual_blocks); + auto sort_kernel = + moe::lora_count_and_sort_expert_tokens_kernel; + + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), cumsum.data_ptr(), + expert_map.data_ptr(), topk_ids.numel(), num_experts, + max_num_tokens_padded, topk_num, token_mask.data_ptr(), + lora_ids.data_ptr(), has_expert_map); + } + }); } // TODO: Jonahbernard: remove this later #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("moe_lora_align_block_size", &moe_lora_align_block_size, "MoE LoRA Align Block Size"); + m.def("moe_lora_align_block_size", &moe_lora_align_block_size, + "MoE LoRA Align Block Size"); } From 39ebafd582efe1c355f35f3496fd6c4203bd9e79 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 6 Feb 2026 12:12:06 -0500 Subject: [PATCH 090/150] finalize fixes --- python/sglang/srt/lora/layers.py | 2 +- python/sglang/srt/lora/lora_moe_runners.py | 28 +- sgl-kernel/csrc/common_extension.cc | 8 + .../csrc/moe/moe_lora_align_sum_kernel.cu | 7 - sgl-kernel/include/sgl_kernel_ops.h | 16 ++ sgl-kernel/python/sgl_kernel/__init__.py | 1 + sgl-kernel/python/sgl_kernel/moe.py | 34 +++ 
.../manual/lora/test_fused_moe_lora_kernel.py | 73 ++++-- test/manual/lora/test_lora_moe_end_to_end.py | 165 ------------ test/manual/lora/test_lora_moe_runner.py | 239 ++++++++---------- test/manual/lora/test_moe_lora_align_sum.py | 103 ++++++-- .../lora/test_lora_hf_sgl_logprob_diff.py | 28 +- 12 files changed, 336 insertions(+), 368 deletions(-) delete mode 100644 test/manual/lora/test_lora_moe_end_to_end.py diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 16eb2c97976e..05d29483bcdc 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -670,7 +670,7 @@ def _get_lora_info( gate_up_lora_b_weights=self.gate_up_lora_b_weights, down_lora_a_weights=self.down_lora_a_weights, down_lora_b_weights=self.down_lora_b_weights, - seg_ind = batch_info.seg_indptr, + seg_indptr = batch_info.seg_indptr, req_to_lora = batch_info.weight_indices, lora_ranks=lora_ranks, adapter_enabled=adapter_enabled, diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 5b1cd1ba3754..e060470b39f4 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -39,11 +39,12 @@ from sglang.srt.utils import is_cuda, is_hip + _is_hip = is_hip() _is_cuda = is_cuda() if _is_cuda or _is_hip: - from sgl_kernel import gelu_and_mul, silu_and_mul + from sgl_kernel import gelu_and_mul, silu_and_mul, moe_lora_align_block_size @dataclass @@ -62,8 +63,11 @@ class LoRAInfo: ) # [num_loras, num_experts, max_rank, intermediate_dim] down_lora_b_weights: torch.Tensor # [num_loras, num_experts, hidden_dim, max_rank] - # Per-token LoRA adapter ID [num_tokens] - token_lora_mapping: torch.Tensor + # Indice pointers of each segment in shape (num_segments + 1, ) + seg_indptr: torch.Tensor + + # The index of lora adapter used by each segment, in shape (num_segments,) + req_to_lora: torch.Tensor # LoRA config per adapter lora_ranks: torch.Tensor # [num_loras] 
@@ -212,6 +216,7 @@ def run( dtype=hidden_states.dtype, ) + invoke_fused_moe_kernel( hidden_states, w13, @@ -274,12 +279,13 @@ def run( ) # Get token-to-LoRA mapping from lora_info - token_lora_mapping = lora_info.token_lora_mapping lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) - moe_lora_ops.moe_lora_align_block_size( + + moe_lora_align_block_size( topk_ids, - token_lora_mapping, + lora_info.seg_indptr, + lora_info.req_to_lora, int(lora_info.num_experts), int(block_size_m), int(max_loras), @@ -362,6 +368,7 @@ def run( else: out_hidden_states = torch.empty_like(hidden_states) + invoke_fused_moe_kernel( intermediate_cache2, w2, @@ -405,6 +412,7 @@ def run( num_tokens_post_padded_lora=num_tokens_post_padded_lora, ) + # ============================================================ # Stage 4: Final reduction (sum across top_k) # ============================================================ @@ -475,6 +483,7 @@ def _add_lora_gate_up_delta( M, top_k, gate_up_dim = intermediate_cache.shape + # Skip LoRA computation if no LoRA adapters have non-zero rank if lora_info.max_lora_rank == 0: return @@ -486,9 +495,8 @@ def _add_lora_gate_up_delta( max_loras = len(lora_info.lora_ranks) - lora_ids = torch.arange( - max_loras, dtype=torch.int32, device=hidden_states.device - ) + lora_ids = torch.arange(max_loras, dtype=torch.int32, device=hidden_states.device) + fused_moe_lora( output=intermediate_cache, @@ -540,6 +548,7 @@ def _add_lora_down_delta( M, top_k, hidden_dim = intermediate_cache.shape + # Skip LoRA computation if no LoRA adapters have non-zero rank if lora_info.max_lora_rank == 0: return @@ -554,6 +563,7 @@ def _add_lora_down_delta( device = intermediate_cache.device lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) + fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 463b924fa79b..681ae8d37c7e 100644 --- 
a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -227,6 +227,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "pad_sorted_token_ids) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + m.def( + "moe_lora_align_block_size(Tensor topk_ids, Tensor seg_indptr, Tensor req_to_lora, " + "int num_experts, int block_size, int max_loras, int max_num_tokens_padded, " + "int max_num_m_blocks, Tensor! sorted_token_ids, Tensor! expert_ids, " + "Tensor! num_tokens_post_pad, Tensor adapter_enabled, Tensor lora_ids, " + "Tensor? maybe_expert_map) -> ()"); + m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); + m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float " "moe_softcapping, Tensor? correction_bias) -> ()"); diff --git a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu index e350f9108448..b88a8f5c7467 100644 --- a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu @@ -848,10 +848,3 @@ void moe_lora_align_block_size( }); } -// TODO: Jonahbernard: remove this later -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("moe_lora_align_block_size", &moe_lora_align_block_size, - "MoE LoRA Align Block Size"); -} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 5e3cf24f9036..29280066e15c 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -324,6 +324,22 @@ void topk_sigmoid( bool renormalize, const c10::optional& correction_bias); +void moe_lora_align_block_size( + torch::Tensor topk_ids, + torch::Tensor seg_indptr, + torch::Tensor req_to_lora, + int64_t num_experts, + int64_t block_size, + int64_t max_loras, + int64_t max_num_tokens_padded, + int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + 
torch::Tensor num_tokens_post_pad, + torch::Tensor adapter_enabled, + torch::Tensor lora_ids, + std::optional maybe_expert_map); + void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor); void moe_sum(torch::Tensor& input, torch::Tensor& output); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 1b97ef94f02b..df8dbd3e41e3 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -94,6 +94,7 @@ kimi_k2_moe_fused_gate, moe_align_block_size, moe_fused_gate, + moe_lora_align_block_size, moe_sum, moe_sum_reduce, prepare_moe_input, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index d85e4b602751..31ee02381e89 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -25,6 +25,40 @@ def moe_align_block_size( ) +def moe_lora_align_block_size( + topk_ids, + seg_indptr, + req_to_lora, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + adapter_enabled, + lora_ids, + maybe_expert_map=None, +): + torch.ops.sgl_kernel.moe_lora_align_block_size.default( + topk_ids, + seg_indptr, + req_to_lora, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + adapter_enabled, + lora_ids, + maybe_expert_map, + ) + + def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/test/manual/lora/test_fused_moe_lora_kernel.py b/test/manual/lora/test_fused_moe_lora_kernel.py index 72c4356e8431..e876aca6e583 100644 --- a/test/manual/lora/test_fused_moe_lora_kernel.py +++ b/test/manual/lora/test_fused_moe_lora_kernel.py @@ -1,8 +1,7 @@ # adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_fused_moe_lora_kernel.py - +import sys import os import random - import pytest 
import torch @@ -14,9 +13,14 @@ get_tensor_model_parallel_world_size, ) from sglang.srt.lora.triton_ops import fused_moe_lora -from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size from sglang.srt.utils import set_random_seed +# ============================================================================== +# IMPORT PREBUILT KERNEL +# ============================================================================== +from sgl_kernel import moe_lora_align_block_size +# ============================================================================== + def round_up(x, base): return ((x + base - 1) // base) * base @@ -38,8 +42,9 @@ def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): max_loras (int): Total number of available LoRA modules. Returns: - torch.Tensor: 1D tensor of shape [num_tokens], where each value - is the LoRA index assigned to that token. + token_lora_mapping (torch.Tensor): 1D tensor of shape [num_tokens] + seg_indptr (torch.Tensor): 1D tensor of shape [num_sequences + 1] + req_to_lora (torch.Tensor): 1D tensor of shape [num_sequences] """ assert num_sequences > 0 and max_loras > 0 assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences" @@ -49,6 +54,8 @@ def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): remainder = num_tokens % num_sequences token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32) + seg_indptr = [0] + req_to_lora = [] start = 0 for seq_idx in range(num_sequences): @@ -61,9 +68,15 @@ def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): # Assign the same LoRA ID to all tokens in this sequence token_lora_mapping[start:end] = lora_id + seg_indptr.append(end) + req_to_lora.append(lora_id) + start = end - return token_lora_mapping + seg_indptr = torch.tensor(seg_indptr, dtype=torch.int32) + req_to_lora = torch.tensor(req_to_lora, dtype=torch.int32) + + return token_lora_mapping, seg_indptr, req_to_lora def 
assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int): @@ -108,14 +121,17 @@ def sample_data( topk_ids, topk_weights = assign_experts_to_tokens( num_tokens, num_experts, top_k_num ) - token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) - return topk_ids, topk_weights, token_lora_mapping + token_lora_mapping, seg_indptr, req_to_lora = assign_loras_to_tokens( + num_tokens, num_sequences, max_loras + ) + return topk_ids, topk_weights, token_lora_mapping, seg_indptr, req_to_lora def use_fused_moe_lora_kernel( topk_ids, topk_weights, - token_lora_mapping, + seg_indptr, + req_to_lora, max_lora_rank, top_k_num, lora_a_stacked, @@ -132,20 +148,25 @@ def use_fused_moe_lora_kernel( max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) + # Important: Ensure output tensors are on the same device as inputs + device = topk_ids.device + # init output tensors sorted_token_ids = torch.empty( (max_loras * max_num_tokens_padded,), dtype=torch.int32, + device=device ) - expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32) - num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) - adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) - lora_ids = torch.arange(max_loras + 2, dtype=torch.int32) + expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32, device=device) + num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32, device=device) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=device) + lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device=device) # call kernel - moe_align_block_size( + moe_lora_align_block_size( topk_ids, - token_lora_mapping, + seg_indptr, + req_to_lora, num_experts, block_size, max_loras, @@ -156,6 +177,7 @@ def use_fused_moe_lora_kernel( num_tokens_post_padded, adapter_enabled, lora_ids, + None # maybe_expert_map ) 
config = { @@ -260,10 +282,17 @@ def test_fused_moe_lora_kernel( # the number of randomly generated sentences. num_sequences = 10 # generate data - topk_ids, topk_weights, token_lora_mapping = sample_data( + topk_ids, topk_weights, token_lora_mapping, seg_indptr, req_to_lora = sample_data( num_tokens, num_sequences, max_loras, num_experts, top_k_num ) + # Ensure generated data is on the correct device + topk_ids = topk_ids.to(device) + topk_weights = topk_weights.to(device) + token_lora_mapping = token_lora_mapping.to(device) + seg_indptr = seg_indptr.to(device) + req_to_lora = req_to_lora.to(device) + # init lora weights lora_a_stacked = [ torch.rand( @@ -274,6 +303,7 @@ def test_fused_moe_lora_kernel( K, ), dtype=dtype, + device=device ) ] lora_b_stacked = [ @@ -285,6 +315,7 @@ def test_fused_moe_lora_kernel( max_lora_rank, ), dtype=dtype, + device=device ) ] hidden_states = torch.rand( @@ -293,14 +324,17 @@ def test_fused_moe_lora_kernel( K, ), dtype=dtype, + device=device ) # fused_moe_lora_kernel output - output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype) + output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype, device=device) + use_fused_moe_lora_kernel( topk_ids, topk_weights, - token_lora_mapping, + seg_indptr, + req_to_lora, max_lora_rank, top_k_num, lora_a_stacked, @@ -322,3 +356,6 @@ def test_fused_moe_lora_kernel( ) torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/manual/lora/test_lora_moe_end_to_end.py b/test/manual/lora/test_lora_moe_end_to_end.py deleted file mode 100644 index 3f3954571a17..000000000000 --- a/test/manual/lora/test_lora_moe_end_to_end.py +++ /dev/null @@ -1,165 +0,0 @@ -""" -End-to-end test for LoRA MoE model inference using SGLang. - -This script loads a LoRA MoE model using SGLang runner, runs inference on a test dataset, -and compares outputs against gold labels to validate correctness. 
-""" - -import json -import os -import sys -from typing import List, Dict, Any -from urllib.request import urlopen - -import torch -from sglang.test.runners import SRTRunner - -# Configuration - set your model and LoRA paths here -MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" # Your LoRA MoE model path -LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" # REQUIRED: Your LoRA adapter path -TEST_DATA_URL = "https://huggingface.co/jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B/blob/main/training_dataset.json" # URL to test data JSON file - - -def load_test_dataset(test_data_url: str) -> List[Dict[str, Any]]: - """Load test dataset from JSON URL.""" - try: - with urlopen(test_data_url) as response: - test_dataset = json.loads(response.read().decode('utf-8')) - except Exception as e: - raise RuntimeError(f"Failed to load test data from URL {test_data_url}: {e}") - - return test_dataset - - -def run_lora_moe_inference_test(): - """Run end-to-end test for LoRA MoE model inference using SGLang.""" - - print("=== LoRA MoE End-to-End Test (SGLang) ===\n") - - print(f"Model: {MODEL_PATH}") - print(f"LoRA Path: {LORA_PATH}") - print(f"Test Data URL: {TEST_DATA_URL}") - print() - - # Load test dataset - try: - test_dataset = load_test_dataset(TEST_DATA_URL) - print(f"Loaded {len(test_dataset)} test cases from {TEST_DATA_URL}") - except Exception as e: - print(f"Error loading test data: {e}") - return False - - # Initialize results tracking - results = [] - total_tests = len(test_dataset) - correct_predictions = 0 - - try: - # Initialize SGLang runner - print("Initializing SGLang runner...") - with SRTRunner( - model_path=MODEL_PATH, - torch_dtype=torch.float32, - model_type="generation", - trust_remote_code=True, - lora_paths=[LORA_PATH], - max_loras_per_batch=1, - ) as runner: - print("SGLang runner initialized successfully. 
Running inference tests...\n") - - # Run inference on each test case - for i, test_case in enumerate(test_dataset, 1): - instruction = test_case["instruction"] - expected_output = test_case["output"] - test_type = test_case["type"] - - print(f"Test {i}/{total_tests}: {test_type}") - print(f"Instruction: {instruction}") - print(f"Expected: '{expected_output}'") - - try: - # Run inference using SGLang runner - model_output = runner.forward( - prompts=[instruction], - max_new_tokens=50, # Adjust as needed for your model - lora_paths=[LORA_PATH], - ) - - # Extract the generated text - generated_output = model_output.output_strs[0] - print(f"Generated: '{generated_output}'") - - # Compare with expected output (exact match for simplicity) - is_correct = generated_output.strip() == expected_output.strip() - - if is_correct: - correct_predictions += 1 - print("✓ PASS") - else: - print("✗ FAIL") - - # Store result - results.append({ - "test_id": i, - "type": test_type, - "instruction": instruction, - "expected": expected_output, - "generated": generated_output, - "correct": is_correct - }) - - except Exception as e: - print(f"✗ ERROR: {e}") - results.append({ - "test_id": i, - "type": test_type, - "instruction": instruction, - "expected": expected_output, - "generated": f"ERROR: {e}", - "correct": False - }) - - print("-" * 50) - - # Print final statistics - accuracy = correct_predictions / total_tests * 100 - - print("\n=== Test Results ===") - print(f"Total tests: {total_tests}") - print(f"Correct predictions: {correct_predictions}") - print(".2f") - print() - - # Print detailed results - print("Detailed Results:") - for result in results: - status = "PASS" if result["correct"] else "FAIL" - print(f"Test {result['test_id']}: {result['type']} - {status}") - - # Group by type - type_stats = {} - for result in results: - test_type = result["type"] - if test_type not in type_stats: - type_stats[test_type] = {"total": 0, "correct": 0} - type_stats[test_type]["total"] += 1 - if 
result["correct"]: - type_stats[test_type]["correct"] += 1 - - print("\nResults by Type:") - for test_type, stats in type_stats.items(): - type_accuracy = stats["correct"] / stats["total"] * 100 - print(f" {test_type}: {stats['correct']}/{stats['total']} ({type_accuracy:.1f}%)") - - return accuracy >= 0 # Always return True for now, adjust threshold as needed - - except Exception as e: - print(f"Test failed with error: {e}") - import traceback - traceback.print_exc() - return False - - -if __name__ == "__main__": - success = run_lora_moe_inference_test() - sys.exit(0 if success else 1) diff --git a/test/manual/lora/test_lora_moe_runner.py b/test/manual/lora/test_lora_moe_runner.py index 0942ca029c79..68616d7aa24f 100644 --- a/test/manual/lora/test_lora_moe_runner.py +++ b/test/manual/lora/test_lora_moe_runner.py @@ -31,28 +31,38 @@ from sglang.srt.utils import set_random_seed -def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): +def generate_request_data(num_tokens: int, num_sequences: int, max_loras: int, device="cuda"): + """ + Generates segment-based request data instead of token-based data. + """ assert num_sequences > 0 and max_loras > 0 assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences" - tokens_per_seq = num_tokens // num_sequences - remainder = num_tokens % num_sequences + # 1. Generate random segment lengths + remaining = num_tokens + seg_lens = [] + for _ in range(num_sequences - 1): + # Ensure at least 1 token per sequence + max_len = remaining - (num_sequences - len(seg_lens)) + 1 + length = random.randint(1, min(max_len, num_tokens // num_sequences * 2)) + seg_lens.append(length) + remaining -= length + seg_lens.append(remaining) # Last segment gets the rest - token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32) + # 2. Build seg_indptr [0, len1, len1+len2, ...] 
+ seg_indptr = torch.cumsum(torch.tensor([0] + seg_lens, dtype=torch.int32, device=device), dim=0, dtype=torch.int32) - start = 0 - for seq_idx in range(num_sequences): - end = start + tokens_per_seq + (1 if seq_idx < remainder else 0) - lora_id = random.randint(0, max_loras - 1) - token_lora_mapping[start:end] = lora_id - start = end + # 3. Assign one LoRA ID per Request + req_to_lora = torch.randint(0, max_loras, (num_sequences,), dtype=torch.int32, device=device) - return token_lora_mapping + # 4. Create dense mapping for the Naive verification function + # (Expand req_to_lora based on seg_lens) + token_lora_mapping = torch.repeat_interleave(req_to_lora, torch.tensor(seg_lens, device=device)) + return seg_indptr, req_to_lora, token_lora_mapping -def assign_experts_to_tokens( - num_tokens: int, num_experts: int, top_k_num: int, dtype=torch.float32 -): + +def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int, dtype=torch.float32): assert top_k_num <= num_experts, "top_k_num must be <= num_experts" expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) @@ -66,76 +76,38 @@ def assign_experts_to_tokens( return expert_indices, expert_weights -def sample_data( - num_tokens: int, - num_sequences: int, - max_loras: int, - num_experts: int, - top_k_num: int, - dtype=torch.float32, -): - topk_ids, topk_weights = assign_experts_to_tokens( - num_tokens, num_experts, top_k_num, dtype - ) - token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) - return topk_ids, topk_weights, token_lora_mapping - - -def create_lora_info( - token_lora_mapping, - topk_ids, - max_loras, - num_experts, - max_lora_rank, - hidden_dim, - intermediate_dim, - gate_up_dim, - dtype, - device, -): +def sample_data(num_tokens: int, num_sequences: int, max_loras: int, num_experts: int, top_k_num: int, dtype=torch.float32, device="cuda"): + topk_ids, topk_weights = assign_experts_to_tokens(num_tokens, num_experts, top_k_num, dtype) + 
seg_indptr, req_to_lora, token_lora_mapping = generate_request_data(num_tokens, num_sequences, max_loras, device) + return topk_ids, topk_weights, seg_indptr, req_to_lora, token_lora_mapping + + +def create_lora_info(seg_indptr, weight_indices, topk_ids, max_loras, num_experts, max_lora_rank, hidden_dim, intermediate_dim, gate_up_dim, dtype, device): # ------------------------------------------------------------------------- # 1. Deterministic LoRA A Initialization # ------------------------------------------------------------------------- - # We fill A with (1 / input_dim). - # If input is all 1s, A @ x will result in a vector of all 1s. - val_gate_up_a = 1.0 / hidden_dim gate_up_lora_a_weights = torch.full( (max_loras, num_experts, max_lora_rank, hidden_dim), - val_gate_up_a, - dtype=dtype, - device=device, + val_gate_up_a, dtype=dtype, device=device ) val_down_a = 1.0 / intermediate_dim down_lora_a_weights = torch.full( (max_loras, num_experts, max_lora_rank, intermediate_dim), - val_down_a, - dtype=dtype, - device=device, + val_down_a, dtype=dtype, device=device ) # ------------------------------------------------------------------------- - # 2. Deterministic LoRA B Initialization (Expert-Specific) + # 2. Deterministic LoRA B Initialization # ------------------------------------------------------------------------- - # We want the output to be safe but noticeable. - # Let's target a base value of 0.1 per expert index. 
- # Formula: fill_value = (target / rank) * (expert_id + 1) - - base_target = 0.01 # Small enough to not explode SiLU, big enough to see + base_target = 0.01 - gate_up_lora_b_weights = torch.zeros( - (max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device - ) - down_lora_b_weights = torch.zeros( - (max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device - ) + gate_up_lora_b_weights = torch.zeros((max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device) + down_lora_b_weights = torch.zeros((max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device) for i in range(num_experts): - # Make every expert add a slightly different value so we can check routing - # Expert 0 adds ~0.01, Expert 10 adds ~0.11 - expert_multiplier = i + 1 - + expert_multiplier = (i + 1) fill_val = (base_target * expert_multiplier) / max_lora_rank gate_up_lora_b_weights[:, i, :, :] = fill_val @@ -144,17 +116,21 @@ def create_lora_info( # ------------------------------------------------------------------------- # 3. 
Setup Metadata # ------------------------------------------------------------------------- - lora_ranks = torch.full( - (max_loras,), max_lora_rank, dtype=torch.int32, device=device - ) - adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=device) + lora_ranks = torch.full((max_loras,), max_lora_rank, dtype=torch.int32, device=device) + + # Enable all adapters referenced in weight_indices + adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int32, device=device) + adapter_enabled.index_fill_(0, weight_indices.long(), 1) return LoRAInfo( gate_up_lora_a_weights=gate_up_lora_a_weights, gate_up_lora_b_weights=gate_up_lora_b_weights, down_lora_a_weights=down_lora_a_weights, down_lora_b_weights=down_lora_b_weights, - token_lora_mapping=token_lora_mapping, + # UPDATED FIELDS + seg_indptr=seg_indptr, + req_to_lora=weight_indices, + lora_ranks=lora_ranks, adapter_enabled=adapter_enabled, max_lora_rank=max_lora_rank, @@ -162,24 +138,20 @@ def create_lora_info( ) -def torch_naive_moe_with_lora( - hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info -): +def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info, token_lora_mapping): + """ + Naive implementation. Note: We pass 'token_lora_mapping' explicitly because + lora_info no longer contains it, but the naive token-loop logic needs it. + """ num_tokens, hidden_dim = hidden_states.shape top_k = topk_ids.shape[1] num_experts = w13.shape[0] - intermediate_dim = w2.shape[2] - hidden_expanded = ( - hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden_dim) - ) + # Expand hidden states for top-k routing + hidden_expanded = hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden_dim) - gate_up_out = torch.zeros( - num_tokens * top_k, - w13.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + # 1. 
Gate/Up Projection (Base) + gate_up_out = torch.zeros(num_tokens * top_k, w13.shape[1], dtype=hidden_states.dtype, device=hidden_states.device) for expert_id in range(num_experts): mask = (topk_ids == expert_id).flatten() @@ -191,38 +163,35 @@ def torch_naive_moe_with_lora( gate_up_out = gate_up_out.view(num_tokens, top_k, -1) + # 1.5. LoRA Gate/Up Delta if lora_info.max_lora_rank > 0: for i in range(num_tokens): for k in range(top_k): expert_id = topk_ids[i, k] - lora_id = lora_info.token_lora_mapping[i] - lora_a = lora_info.gate_up_lora_a_weights[lora_id, expert_id] - lora_b = lora_info.gate_up_lora_b_weights[lora_id, expert_id] - lora_a_result = lora_a @ hidden_states[i] - lora_b_result = lora_b @ lora_a_result - # Using scaling factor of 1.0 since lora_scalings was removed - lora_delta = 1.0 * lora_b_result - gate_up_out[i, k] += lora_delta + lora_id = token_lora_mapping[i] # Use explicit mapping + + # Check if this adapter is enabled/valid + if lora_id < len(lora_info.lora_ranks): + lora_a = lora_info.gate_up_lora_a_weights[lora_id, expert_id] + lora_b = lora_info.gate_up_lora_b_weights[lora_id, expert_id] + lora_a_result = lora_a @ hidden_states[i] + lora_b_result = lora_b @ lora_a_result + gate_up_out[i, k] += lora_b_result + # 2. Activation gate_up_dim = gate_up_out.shape[-1] gate_dim = gate_up_dim // 2 gate = gate_up_out[..., :gate_dim] up = gate_up_out[..., gate_dim:] silu_gate = torch.nn.functional.silu(gate) - intermediate_out = silu_gate * up - down_out = torch.zeros( - num_tokens, - top_k, - hidden_dim, - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + # 3. 
Down Projection (Base) + down_out = torch.zeros(num_tokens, top_k, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device) for expert_id in range(num_experts): - mask = topk_ids == expert_id + mask = (topk_ids == expert_id) if mask.any(): masked_intermediate = intermediate_out[mask] expert_down_result = masked_intermediate @ w2[expert_id].T @@ -230,21 +199,22 @@ def torch_naive_moe_with_lora( if b2 is not None: down_out[mask] += b2[expert_id] + # 3.5. LoRA Down Delta if lora_info.max_lora_rank > 0: for i in range(num_tokens): for k in range(top_k): expert_id = topk_ids[i, k] - lora_id = lora_info.token_lora_mapping[i] - lora_a = lora_info.down_lora_a_weights[lora_id, expert_id] - lora_b = lora_info.down_lora_b_weights[lora_id, expert_id] - lora_a_result = lora_a @ intermediate_out[i, k] - lora_b_result = lora_b @ lora_a_result - # Using scaling factor of 1.0 since lora_scalings was removed - lora_delta = 1.0 * lora_b_result - down_out[i, k] += lora_delta + lora_id = token_lora_mapping[i] # Use explicit mapping - weighted_out = down_out * topk_weights.unsqueeze(-1) + if lora_id < len(lora_info.lora_ranks): + lora_a = lora_info.down_lora_a_weights[lora_id, expert_id] + lora_b = lora_info.down_lora_b_weights[lora_id, expert_id] + lora_a_result = lora_a @ intermediate_out[i, k] + lora_b_result = lora_b @ lora_a_result + down_out[i, k] += lora_b_result + # 4. 
Final Reduction + weighted_out = down_out * topk_weights.unsqueeze(-1) final_out = weighted_out.sum(dim=1) return final_out @@ -254,9 +224,7 @@ def torch_naive_moe_with_lora( @pytest.mark.parametrize("top_k_num", [1, 2]) @pytest.mark.parametrize("num_experts", [8, 20]) @pytest.mark.parametrize("max_lora_rank", [8, 16]) -def test_lora_moe_runner_multi_expert( - num_tokens, top_k_num, num_experts, max_lora_rank -): +def test_lora_moe_runner_multi_expert(num_tokens, top_k_num, num_experts, max_lora_rank): # Fixed parameters max_loras = 2 hidden_dim = 512 @@ -269,28 +237,27 @@ def test_lora_moe_runner_multi_expert( torch.set_default_device(device) set_random_seed(seed) - # Distribute tokens across all experts (Random Routing) - topk_ids, topk_weights = assign_experts_to_tokens( - num_tokens, num_experts, top_k_num, dtype - ) - - # Assign LoRAs randomly num_sequences = 4 - token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) + + # Generate Data using the new Request-Based generator + topk_ids, topk_weights, seg_indptr, req_to_lora, token_lora_mapping = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num, dtype, device + ) gate_up_dim = intermediate_dim * 2 - # Initialize ALL experts with non-zero random weights + # Initialize experts w13 = torch.randn(num_experts, gate_up_dim, hidden_dim, dtype=dtype) * 0.01 w2 = torch.randn(num_experts, hidden_dim, intermediate_dim, dtype=dtype) * 0.01 b13 = torch.randn(num_experts, gate_up_dim, dtype=dtype) * 0.01 b2 = torch.randn(num_experts, hidden_dim, dtype=dtype) * 0.01 - # Set input to random values hidden_states = torch.randn(num_tokens, hidden_dim, dtype=dtype) + # Create LoRA Info using the new fields lora_info = create_lora_info( - token_lora_mapping=token_lora_mapping, + seg_indptr=seg_indptr, + weight_indices=req_to_lora, topk_ids=topk_ids, max_loras=max_loras, num_experts=num_experts, @@ -302,16 +269,14 @@ def test_lora_moe_runner_multi_expert( device=device, ) - # Sort 
tokens by expert ID for the runner + # Sort tokens for the runner topk_ids_flat = topk_ids.flatten() sorted_indices = torch.argsort(topk_ids_flat) sorted_token_ids = sorted_indices // top_k_num expert_ids = topk_ids_flat[sorted_indices] num_dispatched = num_tokens * top_k_num - num_tokens_post_padded = torch.tensor( - [num_dispatched], dtype=torch.int32, device=device - ) + num_tokens_post_padded = torch.tensor([num_dispatched], dtype=torch.int32, device=device) runner_input = TritonRunnerInput( hidden_states=hidden_states, @@ -359,22 +324,20 @@ def test_lora_moe_runner_multi_expert( class MockServerArgs: enable_deterministic_inference = False - with patch( - "sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.get_global_server_args", - return_value=MockServerArgs(), - ): + with patch('sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.get_global_server_args', return_value=MockServerArgs()): runner = MoeRunner(MoeRunnerBackend.TRITON, config, lora_enabled=True) - combine_input = runner.run(dispatch_output, quant_info, lora_info) - lora_output = combine_input + # Run SGLang runner (Uses Kernel) + lora_output = runner.run(dispatch_output, quant_info, lora_info) + # Run Naive Torch Implementation (Uses dense mapping for verification) torch_output = torch_naive_moe_with_lora( - hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info + hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info, token_lora_mapping ) print(f"lora_output.hidden_states mean: {lora_output.hidden_states.mean()}") print(f"torch_output mean: {torch_output.mean()}") - # Assert close - torch.testing.assert_close( - lora_output.hidden_states, torch_output, atol=1e-2, rtol=1e-2 - ) + torch.testing.assert_close(lora_output.hidden_states, torch_output, atol=1e-2, rtol=1e-2) + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/manual/lora/test_moe_lora_align_sum.py b/test/manual/lora/test_moe_lora_align_sum.py index 
d378b33a3ca8..16ee2509b115 100644 --- a/test/manual/lora/test_moe_lora_align_sum.py +++ b/test/manual/lora/test_moe_lora_align_sum.py @@ -4,26 +4,13 @@ import pytest import torch from torch.utils.cpp_extension import load +import sys -# ============================================================================== -# 1. JIT Compile the Kernel -# ============================================================================== -# Pointing specifically to the path you provided -source_path = ".." +# --------------------------------------------------------- +# IMPORT PREBUILT KERNEL +# --------------------------------------------------------- +from sgl_kernel import moe_lora_align_block_size -print(f"Loading kernel from: {source_path}") - -# Check if file exists to avoid confusing compilation errors -if not os.path.exists(source_path): - raise FileNotFoundError(f"Could not find CUDA file at {source_path}") - -moe_ops = load( - name="moe_lora_ops_jit", - sources=[source_path], - extra_cuda_cflags=["-O3"], - verbose=True, -) -print("Kernel loaded successfully.") def round_up(x, base): @@ -35,20 +22,43 @@ def CEILDIV(x, y): def sample_data(num_experts, max_loras, num_tokens, topk_num): + # 1. Generate TopK IDs (Flattened tokens) topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32) - token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32) - for i in range(num_tokens): pool = list(range(num_experts)) random.shuffle(pool) for j in range(topk_num): topk_ids[i, j] = pool[j] - token_lora_mapping[i] = random.randint(0, max_loras - 1) - return topk_ids.to("cuda"), token_lora_mapping.to("cuda") + # 2. 
Generate Random Requests (Segments) + # We split num_tokens into random chunks to simulate a batch of requests + remaining_tokens = num_tokens + seg_lens = [] + while remaining_tokens > 0: + # Random length between 1 and remaining + length = random.randint(1, min(32, remaining_tokens)) + if remaining_tokens - length < 0: + length = remaining_tokens + seg_lens.append(length) + remaining_tokens -= length + + # Ensure we cover the full range exactly (cleanup last segment) + if sum(seg_lens) < num_tokens: + seg_lens.append(num_tokens - sum(seg_lens)) + # 3. Build seg_indptr [0, len1, len1+len2, ...] + seg_indptr = torch.cumsum( + torch.tensor([0] + seg_lens, dtype=torch.int32), dim=0 + ).to(dtype=torch.int32) -@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920 + # 4. Assign a LoRA ID to each Request + num_reqs = len(seg_lens) + req_to_lora = torch.randint(0, max_loras, (num_reqs,), dtype=torch.int32) + + return (topk_ids.to("cuda"), seg_indptr.to("cuda"), req_to_lora.to("cuda")) + + +@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) @pytest.mark.parametrize("topk_num", [6]) @pytest.mark.parametrize("num_experts", [64, 128, 256, 512]) @pytest.mark.parametrize("max_loras", [2, 32]) @@ -58,7 +68,10 @@ def test_moe_lora_align_block_size( ): # sample data random.seed(1) - topk_ids, token_lora_mapping = sample_data( + torch.manual_seed(1) + + # UPDATED: Get the new 3-step mapping tensors + topk_ids, seg_indptr, req_to_lora = sample_data( num_experts, max_loras, num_tokens, topk_num ) @@ -81,10 +94,11 @@ def test_moe_lora_align_block_size( adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda") lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda") - # call kernel - moe_ops.moe_lora_align_block_size( + # UPDATED: Call kernel with new signature + moe_lora_align_block_size( topk_ids, - token_lora_mapping, + seg_indptr, # Arg 2: Pointers + req_to_lora, # Arg 3: Request Map num_experts, block_size, 
max_loras, @@ -102,14 +116,49 @@ def test_moe_lora_align_block_size( expert_ids = expert_ids.view(max_loras, -1) sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size) + # Reconstruct token-level ownership for verification logic + # We expand req_to_lora back to [num_tokens] on CPU just to check correctness + # This proves the kernel (which used the compressed format) produced the right result + cpu_seg_indptr = seg_indptr.cpu() + cpu_req_to_lora = req_to_lora.cpu() + token_ownership = torch.zeros(num_tokens, dtype=torch.int32) + + for r in range(len(cpu_req_to_lora)): + start = cpu_seg_indptr[r] + end = cpu_seg_indptr[r + 1] + token_ownership[start:end] = cpu_req_to_lora[r] + + token_ownership = token_ownership.to("cuda") + for lora_idx in range(max_loras): + # Count how many tokens actually belong to this LoRA + expected_count = (token_ownership == lora_idx).sum().item() + + # Verify the kernel processed a reasonable number of tokens (sanity check) + # Note: num_tokens_post_pad includes padding, so it might be larger than expected_count + assert num_tokens_post_pad[lora_idx].item() >= expected_count * topk_num + for token_idx in range(sorted_token_ids.size(1)): block = sorted_token_ids[lora_idx][token_idx] + # Valid indices are those less than total numel indices = block[block != topk_ids.numel()] + if indices.numel() > 0: + # 1. Verify routing: Does the token actually route to this expert? expert_id = expert_ids[lora_idx][token_idx] assert torch.all(topk_ids.view(-1)[indices] == expert_id) + # 2. Verify ownership: Did the kernel grab the correct tokens for this LoRA? + # The indices in 'sorted_token_ids' point to the flattened [token, topk] array. + # We divide by topk_num to get the original token index. 
+ original_token_indices = indices // topk_num + + # Check that all tokens in this block truly belong to 'lora_idx' + actual_owners = token_ownership[original_token_indices] + assert torch.all( + actual_owners == lora_idx + ), f"Kernel put tokens from LoRA {actual_owners} into block for LoRA {lora_idx}" + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index a27e6a60c12c..7edad8e5a0f7 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -27,7 +27,9 @@ python -m unittest test_lora_hf_sgl_logprob_diff """ +import json import multiprocessing as mp +import os import unittest from typing import Any, Dict, List, Optional, Tuple @@ -100,6 +102,16 @@ def compare_logprobs_for_type( Returns: Dictionary containing comparison statistics """ + # It seems like HF is returning logprob for EOS, but SGLang is not. + min_len = min(sglang_logprobs.shape[0], hf_logprobs.shape[0]) + if sglang_logprobs.shape[0] != hf_logprobs.shape[0]: + print( + f"Warning: {logprob_type} logprob shape mismatch: SGLang {sglang_logprobs.shape}, " + f"HF {hf_logprobs.shape}. Truncating to length {min_len}." 
+ ) + sglang_logprobs = sglang_logprobs[:min_len] + hf_logprobs = hf_logprobs[:min_len] + diff = torch.abs(sglang_logprobs - hf_logprobs) max_diff = torch.max(diff).item() mean_diff = torch.mean(diff).item() @@ -536,7 +548,7 @@ def test_moe_lora_logprob_comparison_basic(self): """ model_path = "Qwen/Qwen1.5-MoE-A2.7B" - lora_paths = ["sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest"] + lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] prompts = DEFAULT_TEST_PROMPTS[:2] # Use first 2 default prompts for basic test self._run_comparison_test( @@ -544,6 +556,7 @@ def test_moe_lora_logprob_comparison_basic(self): lora_paths=lora_paths, prompts=prompts, max_new_tokens=32, + lora_backend="triton", ) def test_moe_lora_logprob_comparison_full(self): @@ -552,14 +565,23 @@ def test_moe_lora_logprob_comparison_full(self): """ model_path = "Qwen/Qwen1.5-MoE-A2.7B" - lora_paths = ["sai-lakkshmii/Qwen1.5-MoE-A2.7B-squad-lora-latest"] - prompts = DEFAULT_TEST_PROMPTS + lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] + + # Load prompts from JSON file + prompts_path = os.path.join( + os.path.dirname(__file__), + "prompts", + "sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json", + ) + with open(prompts_path, "r") as f: + prompts = json.load(f) self._run_comparison_test( model_path=model_path, lora_paths=lora_paths, prompts=prompts, max_new_tokens=32, + lora_backend="triton", ) From cf63435bead42e0317e6c57c102d0172a6c611d6 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 6 Feb 2026 12:14:02 -0500 Subject: [PATCH 091/150] fix --- python/sglang/srt/lora/triton_ops/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index 4976a072a220..5d4684fbba4e 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -3,7 +3,6 @@ from .embedding_lora_a import embedding_lora_a_fwd from 
.fused_moe_lora_kernel import fused_moe_lora from .gate_up_lora_b import gate_up_lora_b_fwd -from .per_expert_lora_moe import per_expert_lora_forward from .qkv_lora_b import qkv_lora_b_fwd from .sgemm_lora_a import sgemm_lora_a_fwd from .sgemm_lora_b import sgemm_lora_b_fwd @@ -15,7 +14,6 @@ "sgemm_lora_b_fwd", "chunked_sgmv_lora_shrink_forward", "chunked_sgmv_lora_expand_forward", - "per_expert_lora_forward", "fused_moe_lora", "embedding_lora_a_fwd", ] From 29ceca1ee19b690412f7c4fe47d10b666da8689a Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 6 Feb 2026 12:15:04 -0500 Subject: [PATCH 092/150] fix --- python/sglang/srt/lora/layers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 05d29483bcdc..5689d4f7ed67 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -738,8 +738,6 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): def get_lora_layer( layer: nn.Module, lora_backend: BaseLoRABackend ) -> BaseLayerWithLoRA: - # FusedMoE is now imported at the top of the file - # FusedMoEWithLoRA is now defined in this file supported_layer_types = { # the order matters From 26709e853a1d7b2dfad19cb7084a678f9ddb6742 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 6 Feb 2026 12:17:31 -0500 Subject: [PATCH 093/150] lint --- python/sglang/srt/lora/layers.py | 19 +- python/sglang/srt/lora/lora_manager.py | 10 +- python/sglang/srt/lora/lora_moe_runners.py | 16 +- .../lora/triton_ops/fused_moe_lora_kernel.py | 25 +- .../csrc/moe/moe_lora_align_sum_kernel.cu | 857 ++++++++++-------- .../manual/lora/test_fused_moe_lora_kernel.py | 34 +- test/manual/lora/test_lora_moe_runner.py | 154 +++- test/manual/lora/test_moe_lora_align_sum.py | 5 +- 8 files changed, 663 insertions(+), 457 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 5689d4f7ed67..82c2d757eea9 100644 --- 
a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -16,12 +16,12 @@ QKVParallelLinear, RowParallelLinear, ) +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.utils import LoRABatchInfo @@ -614,10 +614,11 @@ def __init__( # initialize triton_lora moe runner for batches with lora enabled from sglang.srt.layers.moe.moe_runner.runner import MoeRunner from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo + self._lora_runner = MoeRunner( base_layer.quant_method.runner.runner_backend, base_layer.moe_runner_config, - lora_enabled=True + lora_enabled=True, ) # Pre-compute quant info for efficiency (weights don't change during inference) @@ -642,9 +643,7 @@ def set_lora_info( self.down_lora_a_weights = down_lora_a_weights self.down_lora_b_weights = down_lora_b_weights - def _get_lora_info( - self - ): + def _get_lora_info(self): """ Build LoRAInfo for the current batch. 
@@ -662,7 +661,9 @@ def _get_lora_info( # Create adapter_enabled tensor for the current batch # Only enable LoRA adapters that are actually used in this batch # TODO: Jonahbernard: check that this doesn't slow down inference for this batch - adapter_enabled = torch.zeros(len(lora_ranks), dtype=torch.int32, device=lora_ranks.device) + adapter_enabled = torch.zeros( + len(lora_ranks), dtype=torch.int32, device=lora_ranks.device + ) adapter_enabled.index_fill_(0, batch_info.weight_indices.long(), 1) return LoRAInfo( @@ -670,8 +671,8 @@ def _get_lora_info( gate_up_lora_b_weights=self.gate_up_lora_b_weights, down_lora_a_weights=self.down_lora_a_weights, down_lora_b_weights=self.down_lora_b_weights, - seg_indptr = batch_info.seg_indptr, - req_to_lora = batch_info.weight_indices, + seg_indptr=batch_info.seg_indptr, + req_to_lora=batch_info.weight_indices, lora_ranks=lora_ranks, adapter_enabled=adapter_enabled, max_lora_rank=max_lora_rank, diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index fcf2c017193d..5101162a1ff8 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -21,6 +21,7 @@ import torch from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -28,7 +29,7 @@ ) from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.backend.lora_registry import get_backend_from_name -from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer +from sglang.srt.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, get_lora_layer from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.lora_registry import LoRARef @@ -43,8 +44,6 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import 
replace_submodule from sglang.srt.utils.hf_transformers_utils import AutoConfig -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.lora.layers import FusedMoEWithLoRA logger = logging.getLogger(__name__) @@ -293,7 +292,6 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): use_cuda_graph=use_cuda_graph, ) - def update_lora_info(self): """ Update all LoRA modules to associate them with the latest memory buffer. @@ -378,8 +376,8 @@ def init_state( the target modules and max_lora_rank. """ - assert ( - lora_paths or (max_lora_rank is not None and target_modules is not None) + assert lora_paths or ( + max_lora_rank is not None and target_modules is not None ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." self.init_lora_adapters(lora_paths) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index e060470b39f4..2708d3adadf5 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -38,13 +38,11 @@ ) from sglang.srt.utils import is_cuda, is_hip - - _is_hip = is_hip() _is_cuda = is_cuda() if _is_cuda or _is_hip: - from sgl_kernel import gelu_and_mul, silu_and_mul, moe_lora_align_block_size + from sgl_kernel import gelu_and_mul, moe_lora_align_block_size, silu_and_mul @dataclass @@ -216,7 +214,6 @@ def run( dtype=hidden_states.dtype, ) - invoke_fused_moe_kernel( hidden_states, w13, @@ -281,7 +278,6 @@ def run( # Get token-to-LoRA mapping from lora_info lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) - moe_lora_align_block_size( topk_ids, lora_info.seg_indptr, @@ -368,7 +364,6 @@ def run( else: out_hidden_states = torch.empty_like(hidden_states) - invoke_fused_moe_kernel( intermediate_cache2, w2, @@ -412,7 +407,6 @@ def run( num_tokens_post_padded_lora=num_tokens_post_padded_lora, ) - # 
============================================================ # Stage 4: Final reduction (sum across top_k) # ============================================================ @@ -483,7 +477,6 @@ def _add_lora_gate_up_delta( M, top_k, gate_up_dim = intermediate_cache.shape - # Skip LoRA computation if no LoRA adapters have non-zero rank if lora_info.max_lora_rank == 0: return @@ -495,8 +488,9 @@ def _add_lora_gate_up_delta( max_loras = len(lora_info.lora_ranks) - lora_ids = torch.arange(max_loras, dtype=torch.int32, device=hidden_states.device) - + lora_ids = torch.arange( + max_loras, dtype=torch.int32, device=hidden_states.device + ) fused_moe_lora( output=intermediate_cache, @@ -548,7 +542,6 @@ def _add_lora_down_delta( M, top_k, hidden_dim = intermediate_cache.shape - # Skip LoRA computation if no LoRA adapters have non-zero rank if lora_info.max_lora_rank == 0: return @@ -563,7 +556,6 @@ def _add_lora_down_delta( device = intermediate_cache.device lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) - fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index d11c1c7f5205..fec8c8b24631 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -3,6 +3,7 @@ import torch import triton import triton.language as tl +from sgl_kernel.utils import is_arch_support_pdl from sglang.srt.distributed import ( tensor_model_parallel_all_gather, @@ -11,8 +12,6 @@ # Import SGLang's standard PDL support detection -from sgl_kernel.utils import is_arch_support_pdl - _LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {} @@ -99,8 +98,7 @@ def _fused_moe_lora_kernel( lora_idx = tl.program_id(axis=2) lora_id = tl.load(lora_ids + lora_idx) - USE_GDC = False # TODO (Jonahcb): remove this - + USE_GDC = False # TODO (Jonahcb): remove 
this if lora_id == -1: # Early exit for the no-lora case. @@ -134,7 +132,6 @@ def _fused_moe_lora_kernel( if expert_id == -1: return - # get a_ptr,b_ptr,c_ptr cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) @@ -144,7 +141,6 @@ def _fused_moe_lora_kernel( offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) # ================================================================= secure - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) token_ind = stride_tl * lora_id + offs_token_id offs_token = tl.load( @@ -154,7 +150,6 @@ def _fused_moe_lora_kernel( # ================================================================= secure - # get a_ptrs,b_ptrs a_ptrs = cur_a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak @@ -168,8 +163,6 @@ def _fused_moe_lora_kernel( + offs_bn[None, :] * stride_bn ) - - if USE_GDC and IS_PRIMARY: # GDC launch dependents hints the runtime system to launch dependent kernels. tl.extra.cuda.gdc_launch_dependents() @@ -179,19 +172,13 @@ def _fused_moe_lora_kernel( # accumulator accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # ================================================================= secure - - - # GDC wait waits for ALL programs in the prior kernel to complete # before continuing. 
if USE_GDC and not IS_PRIMARY: tl.extra.cuda.gdc_wait() - - for k in range(0, grid_k): k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) # pre-fetch lora weight @@ -257,7 +244,7 @@ def _fused_moe_lora_shrink( w1_lora_a_stacked = lora_a_stacked[0] # TODO (Jonahcb): investigate why relying on is_arch_support_pdl() is causing crash inside kernel - #use_gdc = is_arch_support_pdl() + # use_gdc = is_arch_support_pdl() use_gdc = False shrink_config = { "BLOCK_SIZE_M": block_size_m, @@ -463,7 +450,10 @@ def _fused_moe_lora( == qcurr_hidden_states.dim() == 2 ) - if sorted_token_ids.shape[0] != expert_ids.shape[0] or sorted_token_ids.shape[0] != num_tokens_post_padded.shape[0]: + if ( + sorted_token_ids.shape[0] != expert_ids.shape[0] + or sorted_token_ids.shape[0] != num_tokens_post_padded.shape[0] + ): x = 1 assert ( sorted_token_ids.shape[0] @@ -495,7 +485,6 @@ def _fused_moe_lora( device=device, ) - _fused_moe_lora_shrink( a_intermediate_cache1, qcurr_hidden_states, diff --git a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu index b88a8f5c7467..e30bd60dcdae 100644 --- a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu @@ -1,11 +1,10 @@ // Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu -#include #include #include +#include #include - #include // ================================================================ @@ -24,40 +23,40 @@ __host__ __device__ inline T round_to_next_multiple_of(T x, T y) { } // Minimal Dispatch Macros to avoid compiling full utils -#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Int: { \ - using scalar_t = int32_t; \ - __VA_ARGS__(); \ - break; \ - } \ - case at::ScalarType::Long: { \ - using scalar_t = int64_t; \ - __VA_ARGS__(); \ - break; \ - } \ - default: \ +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) 
\ + switch (TYPE) { \ + case at::ScalarType::Int: { \ + using scalar_t = int32_t; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::Long: { \ + using scalar_t = int64_t; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ TORCH_CHECK(false, #NAME " not implemented for ", toString(TYPE)); \ } -#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t = float; \ - __VA_ARGS__(); \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__(); \ - break; \ - } \ - case at::ScalarType::BFloat16: { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__(); \ - break; \ - } \ - default: \ +#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ TORCH_CHECK(false, #NAME " not implemented for ", toString(TYPE)); \ } // ================================================================ @@ -71,26 +70,26 @@ namespace batched_moe_align_block_size { static constexpr int32_t num_threads = 1024; static constexpr int32_t num_blocks = 1; __global__ void batched_moe_align_block_size_kernel( - int32_t const num_batches, int32_t const max_tokens_per_batch, - int32_t const block_size, int32_t const* __restrict__ batch_num_tokens, - int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids, + int32_t const num_batches, + int32_t const max_tokens_per_batch, + int32_t const block_size, + int32_t const* __restrict__ batch_num_tokens, + int32_t* __restrict__ sorted_ids, + int32_t* __restrict__ block_ids, int32_t* __restrict__ num_tokens_post_pad) { // TODO(varun): This is a naive implementation. Could be optimized. 
size_t const batch_id = threadIdx.x; size_t const stride = blockDim.x * gridDim.x; - int32_t const num_blocks_per_batch = - CEILDIV(max_tokens_per_batch, block_size); - int32_t const sorted_ids_size = - num_blocks_per_batch * num_batches * block_size; + int32_t const num_blocks_per_batch = CEILDIV(max_tokens_per_batch, block_size); + int32_t const sorted_ids_size = num_blocks_per_batch * num_batches * block_size; int32_t const block_ids_size = sorted_ids_size / block_size; - int32_t const SENTINEL = - num_batches * max_tokens_per_batch; // To denote invalid entries. - // Intialize sorted_ids + int32_t const SENTINEL = num_batches * max_tokens_per_batch; // To denote invalid entries. + // Initialize sorted_ids for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) { sorted_ids[i] = SENTINEL; } - // Intialize expert_ids with -1 + // Initialize expert_ids with -1 for (size_t i = threadIdx.x; i < block_ids_size; i += stride) { block_ids[i] = -1; } @@ -99,8 +98,7 @@ __global__ void batched_moe_align_block_size_kernel( if (batch_id < num_batches) { b_num_tokens = batch_num_tokens[batch_id]; } - int32_t const ceil_b_num_tokens = - CEILDIV(b_num_tokens, block_size) * block_size; + int32_t const ceil_b_num_tokens = CEILDIV(b_num_tokens, block_size) * block_size; // Compute prefix sum over token counts per expert using BlockScan = cub::BlockScan; @@ -132,13 +130,23 @@ __global__ void batched_moe_align_block_size_kernel( template __device__ void _moe_align_block_size( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, int32_t num_experts, - int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, - size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, - int32_t max_num_m_blocks, int32_t model_offset, int32_t 
inactive_expert_id, - int32_t topk_num, int32_t* token_mask, bool has_expert_map) { + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t padded_num_experts, + int32_t experts_per_warp, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, + int32_t model_offset, + int32_t inactive_expert_id, + int32_t topk_num, + int32_t* token_mask, + bool has_expert_map) { extern __shared__ int32_t shared_counts[]; // Compute input buffer offsets. Typically these will all be 0, except when @@ -151,8 +159,7 @@ __device__ void _moe_align_block_size( // This is safe since the current kernel does not use sorted_token_ids. if (blockIdx.x % 2) { // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; - it += blockDim.x) { + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { sorted_token_ids[sorted_token_ids_offset + it] = numel; } return; @@ -185,8 +192,7 @@ __device__ void _moe_align_block_size( int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; int mask = token_mask == nullptr ? 
1 : token_mask[i / topk_num]; - atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], - mask); + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], mask); } __syncthreads(); @@ -217,15 +223,13 @@ __device__ void _moe_align_block_size( __syncthreads(); if (threadIdx.x < num_experts) { - for (int i = cumsum[cumsum_offset + threadIdx.x]; - i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { + for (int i = cumsum[cumsum_offset + threadIdx.x]; i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { expert_ids[expert_ids_offset + i / block_size] = threadIdx.x; } } // Fill remaining expert_ids with 0 - const size_t fill_start_idx = - cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; + const size_t fill_start_idx = cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { expert_ids[expert_ids_offset + i] = inactive_expert_id; } @@ -234,12 +238,20 @@ __device__ void _moe_align_block_size( template __device__ void _moe_align_block_size_small_batch_expert( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, - size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks, - int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num, - int32_t* token_mask, bool has_expert_map) { + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, + int32_t inactive_expert_id, + int32_t model_offset, + int32_t topk_num, + int32_t* token_mask, + bool has_expert_map) { // Compute input buffer offsets. Typically these will all be 0, except when // using Multi LoRA. 
int sorted_token_ids_offset = max_num_tokens_padded * model_offset; @@ -251,8 +263,7 @@ __device__ void _moe_align_block_size_small_batch_expert( // synchronization easier. if (threadIdx.x < fill_threads) { // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; - it += fill_threads) { + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += fill_threads) { sorted_token_ids[sorted_token_ids_offset + it] = numel; } // Three __syncthreads() corresponding to the other threads @@ -289,8 +300,7 @@ __device__ void _moe_align_block_size_small_batch_expert( if (tid < num_experts) { tokens_cnts[tid] = 0; for (int i = 1; i <= stride; ++i) { - tokens_cnts[i * num_experts + tid] += - tokens_cnts[(i - 1) * num_experts + tid]; + tokens_cnts[i * num_experts + tid] += tokens_cnts[(i - 1) * num_experts + tid]; } } @@ -299,13 +309,9 @@ __device__ void _moe_align_block_size_small_batch_expert( if (tid == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = - cumsum[i - 1] + - CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * - block_size; + cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * block_size; } - total_tokens_post_pad[model_offset] = - static_cast(cumsum[num_experts]); + total_tokens_post_pad[model_offset] = static_cast(cumsum[num_experts]); } __syncthreads(); @@ -329,8 +335,7 @@ __device__ void _moe_align_block_size_small_batch_expert( // filter invalid expert if (expert_id == -1) continue; } - int32_t rank_post_pad = - tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; if (token_mask == nullptr || token_mask[i / topk_num]) { sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; @@ -342,10 +347,16 @@ __device__ void _moe_align_block_size_small_batch_expert( template __device__ void _count_and_sort_expert_tokens( const scalar_t* 
__restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, - int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask, - int32_t model_offset, int32_t topk_num, bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, + size_t numel, + int32_t num_experts, + int32_t max_num_tokens_padded, + int32_t* __restrict__ token_mask, + int32_t model_offset, + int32_t topk_num, + bool has_expert_map) { const size_t tid = blockIdx.y * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.y; @@ -362,10 +373,8 @@ __device__ void _count_and_sort_expert_tokens( } if (token_mask == nullptr || token_mask[i / topk_num]) { - int32_t rank_post_pad = atomicAdd( - &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); - sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = - i; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); + sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = i; } } } @@ -373,28 +382,63 @@ __device__ void _count_and_sort_expert_tokens( template __global__ void moe_align_block_size_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, int32_t num_experts, - int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, - size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, - int32_t topk_num, bool has_expert_map) { + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t padded_num_experts, + int32_t experts_per_warp, + int32_t block_size, + size_t numel, + int32_t* __restrict__ 
cumsum, + int32_t max_num_tokens_padded, + int32_t topk_num, + bool has_expert_map) { _moe_align_block_size( - topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, - num_experts, padded_num_experts, experts_per_warp, block_size, numel, - cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size), - 0, 0, topk_num, nullptr, has_expert_map); + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + padded_num_experts, + experts_per_warp, + block_size, + numel, + cumsum, + max_num_tokens_padded, + CEILDIV(max_num_tokens_padded, block_size), + 0, + 0, + topk_num, + nullptr, + has_expert_map); } template __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, - int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, + size_t numel, + int32_t num_experts, + int32_t max_num_tokens_padded, + int32_t topk_num, + bool has_expert_map) { _count_and_sort_expert_tokens( - topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, - max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map); + topk_ids, + sorted_token_ids, + cumsum_buffer, + expert_map, + numel, + num_experts, + max_num_tokens_padded, + nullptr, + 0, + topk_num, + has_expert_map); } template @@ -416,29 +460,58 @@ __global__ void moe_sum_kernel( template __global__ void moe_align_block_size_small_batch_expert_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, int32_t 
num_experts, int32_t block_size, - size_t numel, int32_t max_num_tokens_padded, int32_t topk_num, + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t max_num_tokens_padded, + int32_t topk_num, bool has_expert_map) { _moe_align_block_size_small_batch_expert( - topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, - num_experts, block_size, numel, max_num_tokens_padded, - CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr, + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + block_size, + numel, + max_num_tokens_padded, + CEILDIV(max_num_tokens_padded, block_size), + 0, + 0, + topk_num, + nullptr, has_expert_map); } template __global__ void moe_lora_align_block_size_kernel( - scalar_t* __restrict__ topk_ids, int32_t* __restrict__ seg_indptr, int32_t* __restrict__ req_to_lora, - int num_reqs, int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, - int max_loras, size_t numel, int max_num_tokens_padded, - int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, int32_t topk_num, - int32_t* total_tokens_post_pad, int32_t* adapter_enabled, - int32_t* __restrict__ cumsum, int32_t experts_per_warp, - int32_t padded_num_experts, int32_t* lora_ids, - int32_t* __restrict__ token_mask, bool has_expert_map) { + scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ seg_indptr, + int32_t* __restrict__ req_to_lora, + int num_reqs, + int64_t block_size, + int32_t* __restrict__ expert_map, + int num_experts, + int max_loras, + size_t numel, + int max_num_tokens_padded, + int max_num_m_blocks, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t topk_num, + int32_t* total_tokens_post_pad, + int32_t* adapter_enabled, + int32_t* __restrict__ cumsum, + int32_t experts_per_warp, + int32_t padded_num_experts, + int32_t* lora_ids, + int32_t* __restrict__ 
token_mask, + bool has_expert_map) { int lora_idx = blockIdx.x / 2; int lora_id = lora_ids[lora_idx]; if (lora_id == -1 || adapter_enabled[lora_id] == 0) { @@ -452,12 +525,12 @@ __global__ void moe_lora_align_block_size_kernel( // 1. Parallel Clear (Reset mask to 0) // All threads help clear the mask for this adapter for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { - token_mask[lora_offset + i] = 0; + token_mask[lora_offset + i] = 0; } // Initialize output counter if (threadIdx.x == 0) { - total_tokens_post_pad[lora_id] = 0; + total_tokens_post_pad[lora_id] = 0; } __syncthreads(); @@ -465,33 +538,53 @@ __global__ void moe_lora_align_block_size_kernel( // 2. Segment-based Fill // Iterate over requests. If a request matches this LoRA, fill its range. for (int r = 0; r < num_reqs; ++r) { - if (req_to_lora[r] == lora_id) { - int start = seg_indptr[r]; - int end = seg_indptr[r+1]; - - // Parallel Fill: All threads help mark this segment as "1" - for (int i = start + threadIdx.x; i < end; i += blockDim.x) { - token_mask[lora_offset + i] = 1; - } + if (req_to_lora[r] == lora_id) { + int start = seg_indptr[r]; + int end = seg_indptr[r + 1]; + + // Parallel Fill: All threads help mark this segment as "1" + for (int i = start + threadIdx.x; i < end; i += blockDim.x) { + token_mask[lora_offset + i] = 1; } + } } __syncthreads(); _moe_align_block_size( - topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, - num_experts, padded_num_experts, experts_per_warp, block_size, numel, - cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num, - &token_mask[(lora_id * num_tokens)], has_expert_map); + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + padded_num_experts, + experts_per_warp, + block_size, + numel, + cumsum, + max_num_tokens_padded, + max_num_m_blocks, + lora_id, + -1, + topk_num, + &token_mask[(lora_id * num_tokens)], + has_expert_map); } template __global__ void 
lora_count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, - int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask, - int32_t* lora_ids, bool has_expert_map) { + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, + size_t numel, + int32_t num_experts, + int32_t max_num_tokens_padded, + int32_t topk_num, + int32_t* token_mask, + int32_t* lora_ids, + bool has_expert_map) { int lora_idx = blockIdx.x; int lora_id = lora_ids[lora_idx]; if (lora_id == -1) { @@ -501,20 +594,40 @@ __global__ void lora_count_and_sort_expert_tokens_kernel( int num_tokens = numel / topk_num; _count_and_sort_expert_tokens( - topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, - max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id, - topk_num, has_expert_map); + topk_ids, + sorted_token_ids, + cumsum_buffer, + expert_map, + numel, + num_experts, + max_num_tokens_padded, + &token_mask[(lora_id * num_tokens)], + lora_id, + topk_num, + has_expert_map); } template __global__ void moe_lora_align_block_size_small_batch_expert_kernel( - scalar_t* __restrict__ topk_ids, int32_t* __restrict__ seg_indptr, int32_t* __restrict__ req_to_lora, - int num_reqs, int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, - int max_loras, size_t numel, int max_num_tokens_padded, - int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, int topk_num, - int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids, - int32_t* token_mask, bool has_expert_map) { + scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ seg_indptr, + int32_t* __restrict__ req_to_lora, + int num_reqs, + int64_t block_size, + int32_t* __restrict__ expert_map, + int num_experts, + int 
max_loras, + size_t numel, + int max_num_tokens_padded, + int max_num_m_blocks, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int topk_num, + int32_t* total_tokens_post_pad, + int32_t* adapter_enabled, + int32_t* lora_ids, + int32_t* token_mask, + bool has_expert_map) { int lora_idx = blockIdx.x; int lora_id = lora_ids[lora_idx]; if (lora_id == -1 || adapter_enabled[lora_id] == 0) { @@ -527,12 +640,12 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( // 1. Parallel Clear (Reset mask to 0) // All threads help clear the mask for this adapter for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { - token_mask[lora_offset + i] = 0; + token_mask[lora_offset + i] = 0; } // Initialize output counter if (threadIdx.x == 0) { - total_tokens_post_pad[lora_id] = 0; + total_tokens_post_pad[lora_id] = 0; } __syncthreads(); @@ -540,23 +653,34 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( // 2. Segment-based Fill // Iterate over requests. If a request matches this LoRA, fill its range. 
for (int r = 0; r < num_reqs; ++r) { - if (req_to_lora[r] == lora_id) { - int start = seg_indptr[r]; - int end = seg_indptr[r+1]; - - // Parallel Fill: All threads help mark this segment as "1" - for (int i = start + threadIdx.x; i < end; i += blockDim.x) { - token_mask[lora_offset + i] = 1; - } + if (req_to_lora[r] == lora_id) { + int start = seg_indptr[r]; + int end = seg_indptr[r + 1]; + + // Parallel Fill: All threads help mark this segment as "1" + for (int i = start + threadIdx.x; i < end; i += blockDim.x) { + token_mask[lora_offset + i] = 1; } + } } __syncthreads(); _moe_align_block_size_small_batch_expert( - topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, - num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks, - -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)], + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + block_size, + numel, + max_num_tokens_padded, + max_num_m_blocks, + -1, + lora_id, + topk_num, + &token_mask[(lora_id * num_tokens)], has_expert_map); } @@ -564,24 +688,24 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( // taken from // https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad, - std::optional maybe_expert_map) { +void moe_align_block_size( + torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + std::optional maybe_expert_map) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int64_t padded_num_experts = - ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; int 
experts_per_warp = WARP_SIZE; int threads = 1024; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; // BlockScan uses 1024 threads and assigns one thread per expert. - TORCH_CHECK(padded_num_experts < 1024, - "padded_num_experts must be less than 1024"); - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); + auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); bool has_expert_map = maybe_expert_map.has_value(); torch::Tensor expert_map; if (has_expert_map) { @@ -590,86 +714,89 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, expert_map = torch::empty({0}, options_int); } - DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `cumsum` tensors - bool small_batch_expert_mode = - (topk_ids.numel() < 1024) && (num_experts <= 64); - - if (small_batch_expert_mode) { - const int32_t threads = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem_size = - ((threads + 1) * num_experts + (num_experts + 1)) * - sizeof(int32_t); - - // threadIdx.x >= fill_threads: counting experts and aligning - // threadIdx.x < fill_threads: filling sorted_token_ids - constexpr int32_t fill_threads = 256; - auto small_batch_expert_kernel = - moe::moe_align_block_size_small_batch_expert_kernel< - scalar_t, fill_threads>; - small_batch_expert_kernel<<<1, fill_threads + threads, - shared_mem_size, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - expert_map.data_ptr(), num_experts, block_size, - topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1), - has_expert_map); - } else { - torch::Tensor cumsum_buffer = - torch::empty({num_experts + 1}, options_int); - auto align_kernel = moe::moe_align_block_size_kernel; - - size_t num_warps 
= CEILDIV(padded_num_experts, experts_per_warp); - size_t shared_mem_size = - num_warps * experts_per_warp * sizeof(int32_t); - - // launch two threadblocks - // blockIdx.x == 0: counting experts and aligning - // blockIdx.x == 1: filling sorted_token_ids - align_kernel<<<2, threads, shared_mem_size, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - expert_map.data_ptr(), num_experts, padded_num_experts, - experts_per_warp, block_size, topk_ids.numel(), - cumsum_buffer.data_ptr(), sorted_token_ids.size(0), - topk_ids.size(1), has_expert_map); - - const int block_threads = std::min(256, (int)threads); - const int num_blocks = - (topk_ids.numel() + block_threads - 1) / block_threads; - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - dim3 gridDims(1, actual_blocks); - - auto sort_kernel = - moe::count_and_sort_expert_tokens_kernel; - sort_kernel<<>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), expert_map.data_ptr(), - topk_ids.numel(), num_experts, sorted_token_ids.size(0), - topk_ids.size(1), has_expert_map); - } - }); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `cumsum` tensors + bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t threads = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + auto small_batch_expert_kernel = moe::moe_align_block_size_small_batch_expert_kernel; + small_batch_expert_kernel<<<1, fill_threads + threads, shared_mem_size, stream>>>( + topk_ids.data_ptr(), + 
sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + sorted_token_ids.size(0), + topk_ids.size(1), + has_expert_map); + } else { + torch::Tensor cumsum_buffer = torch::empty({num_experts + 1}, options_int); + auto align_kernel = moe::moe_align_block_size_kernel; + + size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); + size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); + + // launch two threadblocks + // blockIdx.x == 0: counting experts and aligning + // blockIdx.x == 1: filling sorted_token_ids + align_kernel<<<2, threads, shared_mem_size, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), + num_experts, + padded_num_experts, + experts_per_warp, + block_size, + topk_ids.numel(), + cumsum_buffer.data_ptr(), + sorted_token_ids.size(0), + topk_ids.size(1), + has_expert_map); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + dim3 gridDims(1, actual_blocks); + + auto sort_kernel = moe::count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), + expert_map.data_ptr(), + topk_ids.numel(), + num_experts, + sorted_token_ids.size(0), + topk_ids.size(1), + has_expert_map); + } + }); } -void batched_moe_align_block_size(int64_t max_tokens_per_batch, - int64_t block_size, - torch::Tensor const& batch_num_tokens, - torch::Tensor sorted_ids, - torch::Tensor batch_ids, - torch::Tensor num_tokens_post_pad) { +void batched_moe_align_block_size( + int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& batch_num_tokens, + torch::Tensor sorted_ids, + 
torch::Tensor batch_ids, + torch::Tensor num_tokens_post_pad) { namespace batched_kernel = moe::batched_moe_align_block_size; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int32_t const B = batch_num_tokens.size(0); - int32_t const num_blocks_per_batch = - round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; + int32_t const num_blocks_per_batch = round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; int32_t const num_blocks = num_blocks_per_batch * B; int64_t const sorted_ids_size = num_blocks * block_size; @@ -678,15 +805,20 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch, TORCH_CHECK(num_tokens_post_pad.size(0) == 1); TORCH_CHECK(B <= batched_kernel::num_threads); - batched_kernel::batched_moe_align_block_size_kernel<<< - batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>( - B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr(), - sorted_ids.data_ptr(), batch_ids.data_ptr(), - num_tokens_post_pad.data_ptr()); + batched_kernel:: + batched_moe_align_block_size_kernel<<>>( + B, + max_tokens_per_batch, + block_size, + batch_num_tokens.data_ptr(), + sorted_ids.data_ptr(), + batch_ids.data_ptr(), + num_tokens_post_pad.data_ptr()); } -void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] - torch::Tensor& output) // [num_tokens, hidden_size] +void moe_sum( + torch::Tensor& input, // [num_tokens, topk, hidden_size] + torch::Tensor& output) // [num_tokens, hidden_size] { const int hidden_size = input.size(-1); const auto num_tokens = output.numel() / hidden_size; @@ -700,25 +832,22 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] switch (topk) { case 2: DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel<<>>( - output.data_ptr(), input.data_ptr(), - hidden_size); + moe::moe_sum_kernel + <<>>(output.data_ptr(), input.data_ptr(), hidden_size); }); break; case 3: 
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel<<>>( - output.data_ptr(), input.data_ptr(), - hidden_size); + moe::moe_sum_kernel + <<>>(output.data_ptr(), input.data_ptr(), hidden_size); }); break; case 4: DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel<<>>( - output.data_ptr(), input.data_ptr(), - hidden_size); + moe::moe_sum_kernel + <<>>(output.data_ptr(), input.data_ptr(), hidden_size); }); break; @@ -729,33 +858,36 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] } void moe_lora_align_block_size( - torch::Tensor topk_ids, torch::Tensor seg_indptr, torch::Tensor req_to_lora, - int64_t num_experts, int64_t block_size, int64_t max_loras, - int64_t max_num_tokens_padded, int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, - torch::Tensor lora_ids, std::optional maybe_expert_map) { + torch::Tensor topk_ids, + torch::Tensor seg_indptr, + torch::Tensor req_to_lora, + int64_t num_experts, + int64_t block_size, + int64_t max_loras, + int64_t max_num_tokens_padded, + int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, + torch::Tensor adapter_enabled, + torch::Tensor lora_ids, + std::optional maybe_expert_map) { const int topk_num = topk_ids.size(1); TORCH_CHECK(block_size > 0, "block_size should be greater than 0. 
"); int device_max_shared_mem; auto dev = topk_ids.get_device(); - cudaDeviceGetAttribute(&device_max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + cudaDeviceGetAttribute(&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int64_t padded_num_experts = - ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; // BlockScan uses 1024 threads and assigns one thread per expert. - TORCH_CHECK(padded_num_experts < 1024, - "padded_num_experts must be less than 1024"); + TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - torch::Tensor token_mask = - torch::empty({max_loras * topk_ids.size(0)}, options_int); + auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor token_mask = torch::empty({max_loras * topk_ids.size(0)}, options_int); bool has_expert_map = maybe_expert_map.has_value(); torch::Tensor expert_map; if (has_expert_map) { @@ -765,86 +897,103 @@ void moe_lora_align_block_size( } int num_reqs = seg_indptr.size(0) - 1; - DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { - bool small_batch_expert_mode = - (topk_ids.numel() < 1024) && (num_experts <= 64); - - if (small_batch_expert_mode) { - const int32_t num_thread = max((int32_t)num_experts, 128); - const int32_t shared_mem = - (num_thread + 1) * num_experts * sizeof(int32_t) + - (num_experts + 1) * sizeof(int32_t); - if (shared_mem > device_max_shared_mem) { - TORCH_CHECK(false, "Shared memory usage exceeds device limit."); - } - - // threadIdx.x >= fill_threads: counting experts and aligning - // threadIdx.x < fill_threads: filling sorted_token_ids - constexpr int32_t fill_threads = 256; - - dim3 blockDim(num_thread + 
fill_threads); - auto kernel = - moe::moe_lora_align_block_size_small_batch_expert_kernel< - scalar_t, fill_threads>; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); - kernel<<>>( - topk_ids.data_ptr(), - seg_indptr.data_ptr(), req_to_lora.data_ptr(), num_reqs, block_size, - expert_map.data_ptr(), num_experts, max_loras, - topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), lora_ids.data_ptr(), - token_mask.data_ptr(), has_expert_map); - } else { - int num_thread = 1024; - dim3 blockDim(num_thread); - size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); - - size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); - - // cumsum buffer - torch::Tensor cumsum = - torch::zeros({max_loras * (num_experts + 1)}, options_int); - - auto align_kernel = - moe::moe_lora_align_block_size_kernel; - - // launch two threadblocks for each lora - // blockIdx.x % 2 == 0: counting experts and aligning - // blockIdx.x % 2 == 1: filling sorted_token_ids - align_kernel<<>>( - topk_ids.data_ptr(), - seg_indptr.data_ptr(), req_to_lora.data_ptr(), num_reqs, block_size, - expert_map.data_ptr(), num_experts, max_loras, - topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), cumsum.data_ptr(), - WARP_SIZE, padded_num_experts, lora_ids.data_ptr(), - token_mask.data_ptr(), has_expert_map); - - const int block_threads = std::min(256, (int)num_thread); - const int num_blocks = - (topk_ids.numel() + block_threads - 1) / block_threads; - - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - - dim3 gridDims(max_loras, actual_blocks); - auto sort_kernel = - moe::lora_count_and_sort_expert_tokens_kernel; - - sort_kernel<<>>( - 
topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), cumsum.data_ptr(), - expert_map.data_ptr(), topk_ids.numel(), num_experts, - max_num_tokens_padded, topk_num, token_mask.data_ptr(), - lora_ids.data_ptr(), has_expert_map); - } - }); -} + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { + bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t num_thread = max((int32_t)num_experts, 128); + const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + (num_experts + 1) * sizeof(int32_t); + if (shared_mem > device_max_shared_mem) { + TORCH_CHECK(false, "Shared memory usage exceeds device limit."); + } + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + + dim3 blockDim(num_thread + fill_threads); + auto kernel = moe::moe_lora_align_block_size_small_batch_expert_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); + kernel<<>>( + topk_ids.data_ptr(), + seg_indptr.data_ptr(), + req_to_lora.data_ptr(), + num_reqs, + block_size, + expert_map.data_ptr(), + num_experts, + max_loras, + topk_ids.numel(), + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), + lora_ids.data_ptr(), + token_mask.data_ptr(), + has_expert_map); + } else { + int num_thread = 1024; + dim3 blockDim(num_thread); + size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); + + size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); + + // cumsum buffer + torch::Tensor cumsum = torch::zeros({max_loras * (num_experts + 1)}, options_int); + + auto align_kernel = moe::moe_lora_align_block_size_kernel; + + // launch two threadblocks for each lora + // blockIdx.x % 2 == 0: counting experts and 
aligning + // blockIdx.x % 2 == 1: filling sorted_token_ids + align_kernel<<>>( + topk_ids.data_ptr(), + seg_indptr.data_ptr(), + req_to_lora.data_ptr(), + num_reqs, + block_size, + expert_map.data_ptr(), + num_experts, + max_loras, + topk_ids.numel(), + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), + cumsum.data_ptr(), + WARP_SIZE, + padded_num_experts, + lora_ids.data_ptr(), + token_mask.data_ptr(), + has_expert_map); + + const int block_threads = std::min(256, (int)num_thread); + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + dim3 gridDims(max_loras, actual_blocks); + auto sort_kernel = moe::lora_count_and_sort_expert_tokens_kernel; + + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum.data_ptr(), + expert_map.data_ptr(), + topk_ids.numel(), + num_experts, + max_num_tokens_padded, + topk_num, + token_mask.data_ptr(), + lora_ids.data_ptr(), + has_expert_map); + } + }); +} diff --git a/test/manual/lora/test_fused_moe_lora_kernel.py b/test/manual/lora/test_fused_moe_lora_kernel.py index e876aca6e583..8d50ed95e203 100644 --- a/test/manual/lora/test_fused_moe_lora_kernel.py +++ b/test/manual/lora/test_fused_moe_lora_kernel.py @@ -1,24 +1,17 @@ # adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_fused_moe_lora_kernel.py -import sys -import os import random + import pytest import torch -from sglang.srt.distributed import ( - init_distributed_environment, - initialize_model_parallel, -) -from sglang.srt.distributed.parallel_state import ( - get_tensor_model_parallel_world_size, -) -from sglang.srt.lora.triton_ops import fused_moe_lora -from sglang.srt.utils import set_random_seed - # 
============================================================================== # IMPORT PREBUILT KERNEL # ============================================================================== from sgl_kernel import moe_lora_align_block_size + +from sglang.srt.lora.triton_ops import fused_moe_lora +from sglang.srt.utils import set_random_seed + # ============================================================================== @@ -153,11 +146,11 @@ def use_fused_moe_lora_kernel( # init output tensors sorted_token_ids = torch.empty( - (max_loras * max_num_tokens_padded,), - dtype=torch.int32, - device=device + (max_loras * max_num_tokens_padded,), dtype=torch.int32, device=device + ) + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), dtype=torch.int32, device=device ) - expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32, device=device) num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32, device=device) adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=device) lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device=device) @@ -177,7 +170,7 @@ def use_fused_moe_lora_kernel( num_tokens_post_padded, adapter_enabled, lora_ids, - None # maybe_expert_map + None, # maybe_expert_map ) config = { @@ -303,7 +296,7 @@ def test_fused_moe_lora_kernel( K, ), dtype=dtype, - device=device + device=device, ) ] lora_b_stacked = [ @@ -315,7 +308,7 @@ def test_fused_moe_lora_kernel( max_lora_rank, ), dtype=dtype, - device=device + device=device, ) ] hidden_states = torch.rand( @@ -324,7 +317,7 @@ def test_fused_moe_lora_kernel( K, ), dtype=dtype, - device=device + device=device, ) # fused_moe_lora_kernel output @@ -357,5 +350,6 @@ def test_fused_moe_lora_kernel( torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/manual/lora/test_lora_moe_runner.py b/test/manual/lora/test_lora_moe_runner.py index 68616d7aa24f..d3971f75403f 100644 
--- a/test/manual/lora/test_lora_moe_runner.py +++ b/test/manual/lora/test_lora_moe_runner.py @@ -31,7 +31,9 @@ from sglang.srt.utils import set_random_seed -def generate_request_data(num_tokens: int, num_sequences: int, max_loras: int, device="cuda"): +def generate_request_data( + num_tokens: int, num_sequences: int, max_loras: int, device="cuda" +): """ Generates segment-based request data instead of token-based data. """ @@ -47,22 +49,32 @@ def generate_request_data(num_tokens: int, num_sequences: int, max_loras: int, d length = random.randint(1, min(max_len, num_tokens // num_sequences * 2)) seg_lens.append(length) remaining -= length - seg_lens.append(remaining) # Last segment gets the rest + seg_lens.append(remaining) # Last segment gets the rest # 2. Build seg_indptr [0, len1, len1+len2, ...] - seg_indptr = torch.cumsum(torch.tensor([0] + seg_lens, dtype=torch.int32, device=device), dim=0, dtype=torch.int32) + seg_indptr = torch.cumsum( + torch.tensor([0] + seg_lens, dtype=torch.int32, device=device), + dim=0, + dtype=torch.int32, + ) # 3. Assign one LoRA ID per Request - req_to_lora = torch.randint(0, max_loras, (num_sequences,), dtype=torch.int32, device=device) + req_to_lora = torch.randint( + 0, max_loras, (num_sequences,), dtype=torch.int32, device=device + ) # 4. 
Create dense mapping for the Naive verification function # (Expand req_to_lora based on seg_lens) - token_lora_mapping = torch.repeat_interleave(req_to_lora, torch.tensor(seg_lens, device=device)) + token_lora_mapping = torch.repeat_interleave( + req_to_lora, torch.tensor(seg_lens, device=device) + ) return seg_indptr, req_to_lora, token_lora_mapping -def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int, dtype=torch.float32): +def assign_experts_to_tokens( + num_tokens: int, num_experts: int, top_k_num: int, dtype=torch.float32 +): assert top_k_num <= num_experts, "top_k_num must be <= num_experts" expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) @@ -76,26 +88,54 @@ def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int, return expert_indices, expert_weights -def sample_data(num_tokens: int, num_sequences: int, max_loras: int, num_experts: int, top_k_num: int, dtype=torch.float32, device="cuda"): - topk_ids, topk_weights = assign_experts_to_tokens(num_tokens, num_experts, top_k_num, dtype) - seg_indptr, req_to_lora, token_lora_mapping = generate_request_data(num_tokens, num_sequences, max_loras, device) +def sample_data( + num_tokens: int, + num_sequences: int, + max_loras: int, + num_experts: int, + top_k_num: int, + dtype=torch.float32, + device="cuda", +): + topk_ids, topk_weights = assign_experts_to_tokens( + num_tokens, num_experts, top_k_num, dtype + ) + seg_indptr, req_to_lora, token_lora_mapping = generate_request_data( + num_tokens, num_sequences, max_loras, device + ) return topk_ids, topk_weights, seg_indptr, req_to_lora, token_lora_mapping -def create_lora_info(seg_indptr, weight_indices, topk_ids, max_loras, num_experts, max_lora_rank, hidden_dim, intermediate_dim, gate_up_dim, dtype, device): +def create_lora_info( + seg_indptr, + weight_indices, + topk_ids, + max_loras, + num_experts, + max_lora_rank, + hidden_dim, + intermediate_dim, + gate_up_dim, + dtype, + device, +): # 
------------------------------------------------------------------------- # 1. Deterministic LoRA A Initialization # ------------------------------------------------------------------------- val_gate_up_a = 1.0 / hidden_dim gate_up_lora_a_weights = torch.full( (max_loras, num_experts, max_lora_rank, hidden_dim), - val_gate_up_a, dtype=dtype, device=device + val_gate_up_a, + dtype=dtype, + device=device, ) val_down_a = 1.0 / intermediate_dim down_lora_a_weights = torch.full( (max_loras, num_experts, max_lora_rank, intermediate_dim), - val_down_a, dtype=dtype, device=device + val_down_a, + dtype=dtype, + device=device, ) # ------------------------------------------------------------------------- @@ -103,11 +143,15 @@ def create_lora_info(seg_indptr, weight_indices, topk_ids, max_loras, num_expert # ------------------------------------------------------------------------- base_target = 0.01 - gate_up_lora_b_weights = torch.zeros((max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device) - down_lora_b_weights = torch.zeros((max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device) + gate_up_lora_b_weights = torch.zeros( + (max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device + ) + down_lora_b_weights = torch.zeros( + (max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device + ) for i in range(num_experts): - expert_multiplier = (i + 1) + expert_multiplier = i + 1 fill_val = (base_target * expert_multiplier) / max_lora_rank gate_up_lora_b_weights[:, i, :, :] = fill_val @@ -116,7 +160,9 @@ def create_lora_info(seg_indptr, weight_indices, topk_ids, max_loras, num_expert # ------------------------------------------------------------------------- # 3. 
Setup Metadata # ------------------------------------------------------------------------- - lora_ranks = torch.full((max_loras,), max_lora_rank, dtype=torch.int32, device=device) + lora_ranks = torch.full( + (max_loras,), max_lora_rank, dtype=torch.int32, device=device + ) # Enable all adapters referenced in weight_indices adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int32, device=device) @@ -130,7 +176,6 @@ def create_lora_info(seg_indptr, weight_indices, topk_ids, max_loras, num_expert # UPDATED FIELDS seg_indptr=seg_indptr, req_to_lora=weight_indices, - lora_ranks=lora_ranks, adapter_enabled=adapter_enabled, max_lora_rank=max_lora_rank, @@ -138,7 +183,17 @@ def create_lora_info(seg_indptr, weight_indices, topk_ids, max_loras, num_expert ) -def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info, token_lora_mapping): +def torch_naive_moe_with_lora( + hidden_states, + w13, + w2, + b13, + b2, + topk_weights, + topk_ids, + lora_info, + token_lora_mapping, +): """ Naive implementation. Note: We pass 'token_lora_mapping' explicitly because lora_info no longer contains it, but the naive token-loop logic needs it. @@ -148,10 +203,17 @@ def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, top num_experts = w13.shape[0] # Expand hidden states for top-k routing - hidden_expanded = hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden_dim) + hidden_expanded = ( + hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden_dim) + ) # 1. 
Gate/Up Projection (Base) - gate_up_out = torch.zeros(num_tokens * top_k, w13.shape[1], dtype=hidden_states.dtype, device=hidden_states.device) + gate_up_out = torch.zeros( + num_tokens * top_k, + w13.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) for expert_id in range(num_experts): mask = (topk_ids == expert_id).flatten() @@ -168,15 +230,15 @@ def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, top for i in range(num_tokens): for k in range(top_k): expert_id = topk_ids[i, k] - lora_id = token_lora_mapping[i] # Use explicit mapping + lora_id = token_lora_mapping[i] # Use explicit mapping # Check if this adapter is enabled/valid if lora_id < len(lora_info.lora_ranks): - lora_a = lora_info.gate_up_lora_a_weights[lora_id, expert_id] - lora_b = lora_info.gate_up_lora_b_weights[lora_id, expert_id] - lora_a_result = lora_a @ hidden_states[i] - lora_b_result = lora_b @ lora_a_result - gate_up_out[i, k] += lora_b_result + lora_a = lora_info.gate_up_lora_a_weights[lora_id, expert_id] + lora_b = lora_info.gate_up_lora_b_weights[lora_id, expert_id] + lora_a_result = lora_a @ hidden_states[i] + lora_b_result = lora_b @ lora_a_result + gate_up_out[i, k] += lora_b_result # 2. Activation gate_up_dim = gate_up_out.shape[-1] @@ -188,10 +250,16 @@ def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, top intermediate_out = silu_gate * up # 3. 
Down Projection (Base) - down_out = torch.zeros(num_tokens, top_k, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device) + down_out = torch.zeros( + num_tokens, + top_k, + hidden_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) for expert_id in range(num_experts): - mask = (topk_ids == expert_id) + mask = topk_ids == expert_id if mask.any(): masked_intermediate = intermediate_out[mask] expert_down_result = masked_intermediate @ w2[expert_id].T @@ -204,7 +272,7 @@ def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, top for i in range(num_tokens): for k in range(top_k): expert_id = topk_ids[i, k] - lora_id = token_lora_mapping[i] # Use explicit mapping + lora_id = token_lora_mapping[i] # Use explicit mapping if lora_id < len(lora_info.lora_ranks): lora_a = lora_info.down_lora_a_weights[lora_id, expert_id] @@ -224,7 +292,9 @@ def torch_naive_moe_with_lora(hidden_states, w13, w2, b13, b2, topk_weights, top @pytest.mark.parametrize("top_k_num", [1, 2]) @pytest.mark.parametrize("num_experts", [8, 20]) @pytest.mark.parametrize("max_lora_rank", [8, 16]) -def test_lora_moe_runner_multi_expert(num_tokens, top_k_num, num_experts, max_lora_rank): +def test_lora_moe_runner_multi_expert( + num_tokens, top_k_num, num_experts, max_lora_rank +): # Fixed parameters max_loras = 2 hidden_dim = 512 @@ -276,7 +346,9 @@ def test_lora_moe_runner_multi_expert(num_tokens, top_k_num, num_experts, max_lo expert_ids = topk_ids_flat[sorted_indices] num_dispatched = num_tokens * top_k_num - num_tokens_post_padded = torch.tensor([num_dispatched], dtype=torch.int32, device=device) + num_tokens_post_padded = torch.tensor( + [num_dispatched], dtype=torch.int32, device=device + ) runner_input = TritonRunnerInput( hidden_states=hidden_states, @@ -324,20 +396,34 @@ def test_lora_moe_runner_multi_expert(num_tokens, top_k_num, num_experts, max_lo class MockServerArgs: enable_deterministic_inference = False - with 
patch('sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.get_global_server_args', return_value=MockServerArgs()): + with patch( + "sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.get_global_server_args", + return_value=MockServerArgs(), + ): runner = MoeRunner(MoeRunnerBackend.TRITON, config, lora_enabled=True) # Run SGLang runner (Uses Kernel) lora_output = runner.run(dispatch_output, quant_info, lora_info) # Run Naive Torch Implementation (Uses dense mapping for verification) torch_output = torch_naive_moe_with_lora( - hidden_states, w13, w2, b13, b2, topk_weights, topk_ids, lora_info, token_lora_mapping + hidden_states, + w13, + w2, + b13, + b2, + topk_weights, + topk_ids, + lora_info, + token_lora_mapping, ) print(f"lora_output.hidden_states mean: {lora_output.hidden_states.mean()}") print(f"torch_output mean: {torch_output.mean()}") - torch.testing.assert_close(lora_output.hidden_states, torch_output, atol=1e-2, rtol=1e-2) + torch.testing.assert_close( + lora_output.hidden_states, torch_output, atol=1e-2, rtol=1e-2 + ) + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/manual/lora/test_moe_lora_align_sum.py b/test/manual/lora/test_moe_lora_align_sum.py index 16ee2509b115..8c7a0100c7ca 100644 --- a/test/manual/lora/test_moe_lora_align_sum.py +++ b/test/manual/lora/test_moe_lora_align_sum.py @@ -1,10 +1,8 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_moe_lora_align_sum.py import random -import os + import pytest import torch -from torch.utils.cpp_extension import load -import sys # --------------------------------------------------------- # IMPORT PREBUILT KERNEL @@ -12,7 +10,6 @@ from sgl_kernel import moe_lora_align_block_size - def round_up(x, base): return ((x + base - 1) // base) * base From 27336a5656c43f62f7d5dee1768c5871c2ce25ad Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 7 Feb 2026 14:30:07 -0500 Subject: [PATCH 094/150] fix merge conflict --- 
python/sglang/srt/lora/lora_moe_runners.py | 59 +++++++++++++++++----- 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 2708d3adadf5..9363b2929265 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -23,6 +23,7 @@ from __future__ import annotations +import os from dataclasses import dataclass from typing import Optional @@ -36,13 +37,41 @@ TritonRunnerInput, TritonRunnerOutput, ) -from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_xpu _is_hip = is_hip() _is_cuda = is_cuda() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() +_use_aiter = bool(int(os.getenv("SGLANG_USE_AITER", "0"))) +_is_xpu = is_xpu() +_MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 + if _is_cuda or _is_hip: - from sgl_kernel import gelu_and_mul, moe_lora_align_block_size, silu_and_mul + from sgl_kernel import gelu_and_mul, silu_and_mul + + if _is_hip: + if _use_aiter: + try: + from aiter import moe_sum + except ImportError: + raise ImportError( + "aiter is required when SGLANG_USE_AITER is set to True" + ) + else: + from vllm import _custom_ops as vllm_ops # moe_sum +elif _is_cpu and _is_cpu_amx_available: + pass +elif _is_xpu: + from sgl_kernel import silu_and_mul + + +if _is_cuda or _is_hip or _is_xpu: + from sgl_kernel import ( # noqa: F401 + moe_align_block_size as sgl_moe_align_block_size, + ) + from sgl_kernel import moe_lora_align_block_size @dataclass @@ -157,12 +186,13 @@ def run( tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 ) - # Import functions needed for MoE computation + # TODO: move these functions to the triton runner from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + _swiglu_gpt_oss_sigmoid_alpha, + _swiglu_silu_clamp_mul, invoke_fused_moe_kernel, 
moe_sum_reduce_torch_compile, moe_sum_reduce_triton, - swiglu_with_alpha_and_limit, ) hidden_states = runner_input.hidden_states @@ -320,26 +350,31 @@ def run( device=hidden_states.device, dtype=hidden_states.dtype, ) - if activation == "silu": if gemm1_alpha is not None: assert gemm1_limit is not None - intermediate_cache2 = swiglu_with_alpha_and_limit( - intermediate_cache1.view(-1, N), - gemm1_alpha, - gemm1_limit, + intermediate_cache2 = _swiglu_gpt_oss_sigmoid_alpha( + intermediate_cache1.view(-1, N), gemm1_alpha, gemm1_limit + ) + elif gemm1_limit is not None: + intermediate_cache2 = _swiglu_silu_clamp_mul( + intermediate_cache1.view(-1, N), gemm1_limit ) - elif _is_cuda or _is_hip: + elif _is_cuda or _is_hip or _is_xpu: silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: - raise ValueError(f"Unsupported platform for activation: {activation=}") + vllm_ops.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) elif activation == "gelu": assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" assert gemm1_limit is None, "gemm1_limit is not supported for gelu" if _is_cuda or _is_hip: gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: - raise ValueError(f"Unsupported platform for activation: {activation=}") + vllm_ops.gelu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) else: raise ValueError(f"Unsupported activation: {activation=}") From e062d4f01e627b9a2bf17a131650af76e5dcc710 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 7 Feb 2026 17:48:01 -0500 Subject: [PATCH 095/150] fix comments --- python/sglang/srt/lora/layers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 82c2d757eea9..55b702562b02 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -726,13 +726,9 @@ def _forward_with_lora( return final_hidden_states def slice_lora_a_weights(self, A: 
torch.Tensor, tp_rank: int): - # For MoE layers, tensor parallelism is typically not used - # Return weights unchanged return A def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): - # For MoE layers, tensor parallelism is typically not used - # Return weights unchanged return B From 1b8e359e073a7aece7d771699871b8464f57600b Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 7 Feb 2026 18:04:00 -0500 Subject: [PATCH 096/150] remove unused code --- python/sglang/srt/lora/mem_pool.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 711a5769b3b5..ed99230af6db 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -575,24 +575,6 @@ def load_lora_weight_tensor( # Skip weight slicing if the weight is not present in the adapter continue - # Handle MoE modules (they contain dicts of per-expert tensors) - if isinstance(temp_A_buffer[target_module], dict): - # Slice each expert's weights individually - for expert_id in temp_A_buffer[target_module].keys(): - temp_A_buffer[target_module][expert_id] = ( - module.slice_lora_a_weights( - temp_A_buffer[target_module][expert_id], - self.tp_rank, - ) - ) - temp_B_buffer[target_module][expert_id] = ( - module.slice_lora_b_weights( - temp_B_buffer[target_module][expert_id], - self.tp_rank, - ) - ) - continue - # Handle standard modules temp_A_buffer[target_module] = module.slice_lora_a_weights( temp_A_buffer[target_module], self.tp_rank From bae8100a3320035d597b9359985ed102082f8328 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 7 Feb 2026 18:42:05 -0500 Subject: [PATCH 097/150] better check in mempool --- python/sglang/srt/lora/mem_pool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index ed99230af6db..9c9f117740e7 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ 
b/python/sglang/srt/lora/mem_pool.py @@ -588,7 +588,7 @@ def load_lora_weight_tensor( c = get_stacked_multiply(name) # TODO: delete this target_buffer = self.A_buffer[name][layer_id] - if isinstance(weights, dict): + if name in ["gate_up_proj_moe", "down_proj_moe"]: # MoE: multiple tensors per module (one per expert) for expert_id, expert_weight in weights.items(): # Buffer shape: [num_loras, num_experts, max_rank, hidden_dim] @@ -605,7 +605,7 @@ def load_lora_weight_tensor( for name, weights in temp_B_buffer.items(): target_buffer = self.B_buffer[name][layer_id] - if isinstance(weights, dict): + if name in ["gate_up_proj_moe", "down_proj_moe"]: # MoE: multiple tensors per module (one per expert) for expert_id, expert_weight in weights.items(): # Buffer shape: [num_loras, num_experts, intermediate_dim, max_rank] From 297abccd1e34a18e6d8e134a06513e2197594d36 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 8 Feb 2026 09:48:52 -0500 Subject: [PATCH 098/150] code quality --- python/sglang/srt/lora/mem_pool.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 9c9f117740e7..ab12bf855f13 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -162,22 +162,16 @@ def get_lora_A_shape( if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: input_dim = divide(input_dim, self.tp_size) - # Check if MoE module and return appropriate shape (the assumption is that down_proj and gate_up_proj are only used in MoE modules) if self.is_moe_module(module_name): num_experts = getattr( self.base_hf_config, "num_local_experts", getattr(self.base_hf_config, "num_experts", 0), ) - # Allocate all MoE buffers with the same maximum rank dimension - # to ensure consistent kernel compilation. The maximum stacking factor is 2. 
- max_rank_dim = ( - max_lora_dim * 2 - ) # Accommodate maximum stacking (gate_up_proj) return ( self.max_loras_per_batch, num_experts, - max_rank_dim, + max_lora_dim * c, input_dim, ) else: From 5b2c58547db4860d2919577d70fe7091f41e1707 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 8 Feb 2026 15:01:17 -0500 Subject: [PATCH 099/150] improve code quality --- python/sglang/srt/lora/mem_pool.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index ab12bf855f13..8d99a38683b6 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -163,11 +163,7 @@ def get_lora_A_shape( input_dim = divide(input_dim, self.tp_size) if self.is_moe_module(module_name): - num_experts = getattr( - self.base_hf_config, - "num_local_experts", - getattr(self.base_hf_config, "num_experts", 0), - ) + num_experts = self.base_model.config.num_experts return ( self.max_loras_per_batch, num_experts, @@ -216,11 +212,7 @@ def get_lora_B_shape( # Check if MoE module and return appropriate shape if self.is_moe_module(module_name): - num_experts = getattr( - self.base_hf_config, - "num_local_experts", - getattr(self.base_hf_config, "num_experts", 0), - ) + num_experts = self.base_model.config.num_experts return (self.max_loras_per_batch, num_experts, output_dim, max_lora_dim) else: return (self.max_loras_per_batch, output_dim, max_lora_dim) From 05a9ca8c980328f4daf35871ab41cf9ce29b3909 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 9 Feb 2026 10:25:09 -0500 Subject: [PATCH 100/150] remove unused code --- python/sglang/srt/model_executor/forward_batch_info.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index e69af77c2718..bcd1fda10a02 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ 
b/python/sglang/srt/model_executor/forward_batch_info.py @@ -310,8 +310,6 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): # For LoRA lora_ids: Optional[List[str]] = None - # Per-token LoRA adapter indices (expanded from lora_ids) - token_lora_indices: Optional[torch.Tensor] = None # For input embeddings input_embeds: Optional[torch.Tensor] = None From a329d015997f1f49a8116567b930b1ce7de65652 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 9 Feb 2026 13:16:24 -0500 Subject: [PATCH 101/150] remove unused code --- python/sglang/srt/lora/lora_moe_runners.py | 43 ---------------------- 1 file changed, 43 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 9363b2929265..0763cc926dc4 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -141,9 +141,6 @@ def run( Returns: TritonRunnerOutput with combined base + LoRA output """ - # If no LoRA, use base implementation - if lora_info is None: - return super().run(runner_input, quant_info, running_state) # Extract common variables hidden_states = runner_input.hidden_states @@ -195,46 +192,6 @@ def run( moe_sum_reduce_triton, ) - hidden_states = runner_input.hidden_states - topk_weights = runner_input.topk_weights - topk_ids = runner_input.topk_ids - sorted_token_ids = runner_input.sorted_token_ids - expert_ids = runner_input.expert_ids - num_tokens_post_padded = runner_input.num_tokens_post_padded - - w13 = quant_info.w13_weight - w2 = quant_info.w2_weight - b13 = quant_info.b13 - b2 = quant_info.b2 - a13_scale = quant_info.a13_scale - a2_scale = quant_info.a2_scale - w13_scale = quant_info.w13_scale - w2_scale = quant_info.w2_scale - w13_zp = quant_info.w13_zp - w2_zp = quant_info.w2_zp - block_shape = quant_info.block_shape - per_channel_quant = quant_info.per_channel_quant - use_fp8_w8a8 = quant_info.use_fp8_w8a8 - use_int8_w8a8 = quant_info.use_int8_w8a8 - use_int8_w8a16 = 
quant_info.use_int8_w8a16 - use_int4_w4a16 = quant_info.use_int4_w4a16 - - activation = self.config.activation - no_combine = self.config.no_combine - inplace = self.config.inplace - gemm1_alpha = self.config.gemm1_alpha - gemm1_limit = self.config.gemm1_clamp_limit - routed_scaling_factor = self.config.routed_scaling_factor - apply_router_weight_on_input = self.config.apply_router_weight_on_input - - assert self.config.is_gated, "Only gated MoEs are supported for Triton runner" - - M = hidden_states.shape[0] - E, N, _ = w13.shape - compute_type = ( - tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - ) - # ============================================================ # Stage 1: Gate/Up projection (base) # ============================================================ From d477548f9371ddd1508e656bd005b345676e975f Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 9 Feb 2026 15:34:46 -0500 Subject: [PATCH 102/150] add GDC support --- .../sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index fec8c8b24631..833a2dafdcdd 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -10,6 +10,8 @@ tensor_model_parallel_all_reduce, ) +from sglang.srt.utils.common import is_sm90_supported, is_blackwell_supported + # Import SGLang's standard PDL support detection @@ -98,7 +100,6 @@ def _fused_moe_lora_kernel( lora_idx = tl.program_id(axis=2) lora_id = tl.load(lora_ids + lora_idx) - USE_GDC = False # TODO (Jonahcb): remove this if lora_id == -1: # Early exit for the no-lora case. 
@@ -243,9 +244,7 @@ def _fused_moe_lora_shrink( ) -> None: w1_lora_a_stacked = lora_a_stacked[0] - # TODO (Jonahcb): investigate why relying on is_arch_support_pdl() is causing crash inside kernel - # use_gdc = is_arch_support_pdl() - use_gdc = False + use_gdc = (is_sm90_supported() or is_blackwell_supported()) shrink_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, @@ -350,7 +349,7 @@ def _fused_moe_lora_expand( -1, a_intermediate_cache1.shape[3] ) - use_gdc = is_arch_support_pdl() + use_gdc = (is_sm90_supported() or is_blackwell_supported()) expand_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, From 137c9cb509e993ebdbc06574317386fab4e25dbe Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 10 Feb 2026 10:27:28 -0500 Subject: [PATCH 103/150] remove unused code --- .../csrc/moe/moe_lora_align_sum_kernel.cu | 55 ------------------- 1 file changed, 55 deletions(-) diff --git a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu index e30bd60dcdae..8a5d2a263353 100644 --- a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu @@ -441,21 +441,6 @@ __global__ void count_and_sort_expert_tokens_kernel( has_expert_map); } -template -__global__ void moe_sum_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., topk, d] - const int d) { - const int64_t token_idx = blockIdx.x; - for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - scalar_t x = 0.0; -#pragma unroll - for (int k = 0; k < TOPK; ++k) { - x += __ldg(&input[token_idx * TOPK * d + k * d + idx]); - } - out[token_idx * d + idx] = x; - } -} template __global__ void moe_align_block_size_small_batch_expert_kernel( @@ -816,46 +801,6 @@ void batched_moe_align_block_size( num_tokens_post_pad.data_ptr()); } -void moe_sum( - torch::Tensor& input, // [num_tokens, topk, hidden_size] - torch::Tensor& output) // [num_tokens, 
hidden_size] -{ - const int hidden_size = input.size(-1); - const auto num_tokens = output.numel() / hidden_size; - const int topk = input.size(1); - - dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - switch (topk) { - case 2: - DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel - <<>>(output.data_ptr(), input.data_ptr(), hidden_size); - }); - break; - - case 3: - DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel - <<>>(output.data_ptr(), input.data_ptr(), hidden_size); - }); - break; - - case 4: - DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { - moe::moe_sum_kernel - <<>>(output.data_ptr(), input.data_ptr(), hidden_size); - }); - break; - - default: - at::sum_out(output, input, 1); - break; - } -} void moe_lora_align_block_size( torch::Tensor topk_ids, From 2940aa2d0eea5b016d33d248c7ba563b6b631011 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 10 Feb 2026 11:21:39 -0500 Subject: [PATCH 104/150] remove unused code --- .../csrc/moe/moe_lora_align_sum_kernel.cu | 282 ------------------ 1 file changed, 282 deletions(-) diff --git a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu index 8a5d2a263353..63030ea18dd4 100644 --- a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu @@ -64,68 +64,10 @@ __host__ __device__ inline T round_to_next_multiple_of(T x, T y) { #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace moe { -namespace batched_moe_align_block_size { // Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. 
static constexpr int32_t num_threads = 1024; static constexpr int32_t num_blocks = 1; -__global__ void batched_moe_align_block_size_kernel( - int32_t const num_batches, - int32_t const max_tokens_per_batch, - int32_t const block_size, - int32_t const* __restrict__ batch_num_tokens, - int32_t* __restrict__ sorted_ids, - int32_t* __restrict__ block_ids, - int32_t* __restrict__ num_tokens_post_pad) { - // TODO(varun): This is a naive implementation. Could be optimized. - - size_t const batch_id = threadIdx.x; - size_t const stride = blockDim.x * gridDim.x; - int32_t const num_blocks_per_batch = CEILDIV(max_tokens_per_batch, block_size); - int32_t const sorted_ids_size = num_blocks_per_batch * num_batches * block_size; - int32_t const block_ids_size = sorted_ids_size / block_size; - int32_t const SENTINEL = num_batches * max_tokens_per_batch; // To denote invalid entries. - // Initialize sorted_ids - for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) { - sorted_ids[i] = SENTINEL; - } - // Initialize expert_ids with -1 - for (size_t i = threadIdx.x; i < block_ids_size; i += stride) { - block_ids[i] = -1; - } - - int32_t b_num_tokens = 0; - if (batch_id < num_batches) { - b_num_tokens = batch_num_tokens[batch_id]; - } - int32_t const ceil_b_num_tokens = CEILDIV(b_num_tokens, block_size) * block_size; - - // Compute prefix sum over token counts per expert - using BlockScan = cub::BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - int cumsum_val; - BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val); - __syncthreads(); - - bool const is_last_batch = batch_id == (num_batches - 1); - if (is_last_batch) { - *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens; - } - - if (batch_id < num_batches) { - int32_t const batch_offset = batch_id * max_tokens_per_batch; - for (size_t i = 0; i < b_num_tokens; ++i) { - sorted_ids[cumsum_val + i] = batch_offset + i; - } - - int32_t const block_start = cumsum_val / block_size; - int32_t 
const num_blocks = ceil_b_num_tokens / block_size; - for (size_t i = 0; i < num_blocks; ++i) { - block_ids[block_start + i] = batch_id; - } - } -} -} // namespace batched_moe_align_block_size template __device__ void _moe_align_block_size( @@ -379,100 +321,6 @@ __device__ void _count_and_sort_expert_tokens( } } -template -__global__ void moe_align_block_size_kernel( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, - int32_t num_experts, - int32_t padded_num_experts, - int32_t experts_per_warp, - int32_t block_size, - size_t numel, - int32_t* __restrict__ cumsum, - int32_t max_num_tokens_padded, - int32_t topk_num, - bool has_expert_map) { - _moe_align_block_size( - topk_ids, - sorted_token_ids, - expert_ids, - total_tokens_post_pad, - expert_map, - num_experts, - padded_num_experts, - experts_per_warp, - block_size, - numel, - cumsum, - max_num_tokens_padded, - CEILDIV(max_num_tokens_padded, block_size), - 0, - 0, - topk_num, - nullptr, - has_expert_map); -} - -template -__global__ void count_and_sort_expert_tokens_kernel( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, - size_t numel, - int32_t num_experts, - int32_t max_num_tokens_padded, - int32_t topk_num, - bool has_expert_map) { - _count_and_sort_expert_tokens( - topk_ids, - sorted_token_ids, - cumsum_buffer, - expert_map, - numel, - num_experts, - max_num_tokens_padded, - nullptr, - 0, - topk_num, - has_expert_map); -} - - -template -__global__ void moe_align_block_size_small_batch_expert_kernel( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, - int32_t num_experts, - int32_t block_size, - size_t 
numel, - int32_t max_num_tokens_padded, - int32_t topk_num, - bool has_expert_map) { - _moe_align_block_size_small_batch_expert( - topk_ids, - sorted_token_ids, - expert_ids, - total_tokens_post_pad, - expert_map, - num_experts, - block_size, - numel, - max_num_tokens_padded, - CEILDIV(max_num_tokens_padded, block_size), - 0, - 0, - topk_num, - nullptr, - has_expert_map); -} - template __global__ void moe_lora_align_block_size_kernel( scalar_t* __restrict__ topk_ids, @@ -671,136 +519,6 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( } // namespace moe -// taken from -// https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc -void moe_align_block_size( - torch::Tensor topk_ids, - int64_t num_experts, - int64_t block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad, - std::optional maybe_expert_map) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - int experts_per_warp = WARP_SIZE; - int threads = 1024; - threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - - // BlockScan uses 1024 threads and assigns one thread per expert. 
- TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); - auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - bool has_expert_map = maybe_expert_map.has_value(); - torch::Tensor expert_map; - if (has_expert_map) { - expert_map = maybe_expert_map.value(); - } else { - expert_map = torch::empty({0}, options_int); - } - - DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `cumsum` tensors - bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); - - if (small_batch_expert_mode) { - const int32_t threads = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); - - // threadIdx.x >= fill_threads: counting experts and aligning - // threadIdx.x < fill_threads: filling sorted_token_ids - constexpr int32_t fill_threads = 256; - auto small_batch_expert_kernel = moe::moe_align_block_size_small_batch_expert_kernel; - small_batch_expert_kernel<<<1, fill_threads + threads, shared_mem_size, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - expert_map.data_ptr(), - num_experts, - block_size, - topk_ids.numel(), - sorted_token_ids.size(0), - topk_ids.size(1), - has_expert_map); - } else { - torch::Tensor cumsum_buffer = torch::empty({num_experts + 1}, options_int); - auto align_kernel = moe::moe_align_block_size_kernel; - - size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); - size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); - - // launch two threadblocks - // blockIdx.x == 0: counting experts and aligning - // blockIdx.x == 1: filling sorted_token_ids - align_kernel<<<2, threads, shared_mem_size, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - 
num_tokens_post_pad.data_ptr(), - expert_map.data_ptr(), - num_experts, - padded_num_experts, - experts_per_warp, - block_size, - topk_ids.numel(), - cumsum_buffer.data_ptr(), - sorted_token_ids.size(0), - topk_ids.size(1), - has_expert_map); - - const int block_threads = std::min(256, (int)threads); - const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - dim3 gridDims(1, actual_blocks); - - auto sort_kernel = moe::count_and_sort_expert_tokens_kernel; - sort_kernel<<>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), - expert_map.data_ptr(), - topk_ids.numel(), - num_experts, - sorted_token_ids.size(0), - topk_ids.size(1), - has_expert_map); - } - }); -} - -void batched_moe_align_block_size( - int64_t max_tokens_per_batch, - int64_t block_size, - torch::Tensor const& batch_num_tokens, - torch::Tensor sorted_ids, - torch::Tensor batch_ids, - torch::Tensor num_tokens_post_pad) { - namespace batched_kernel = moe::batched_moe_align_block_size; - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int32_t const B = batch_num_tokens.size(0); - int32_t const num_blocks_per_batch = round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; - int32_t const num_blocks = num_blocks_per_batch * B; - int64_t const sorted_ids_size = num_blocks * block_size; - - TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size); - TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size); - TORCH_CHECK(num_tokens_post_pad.size(0) == 1); - TORCH_CHECK(B <= batched_kernel::num_threads); - - batched_kernel:: - batched_moe_align_block_size_kernel<<>>( - B, - max_tokens_per_batch, - block_size, - batch_num_tokens.data_ptr(), - sorted_ids.data_ptr(), - batch_ids.data_ptr(), - num_tokens_post_pad.data_ptr()); -} - void moe_lora_align_block_size( torch::Tensor topk_ids, From 9ac01d3c27eea5bfd7c1816630a1ecc35cfd3646 
Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 14 Feb 2026 09:12:46 -0500 Subject: [PATCH 105/150] move token sorting kernel to jit kernel folder --- python/sglang/srt/lora/lora_moe_runners.py | 2 +- sgl-kernel/csrc/common_extension.cc | 8 - .../csrc/moe/moe_lora_align_sum_kernel.cu | 662 ------------------ sgl-kernel/include/sgl_kernel_ops.h | 16 - sgl-kernel/python/sgl_kernel/__init__.py | 1 - sgl-kernel/python/sgl_kernel/moe.py | 34 - .../manual/lora/test_fused_moe_lora_kernel.py | 2 +- test/manual/lora/test_moe_lora_align_sum.py | 2 +- 8 files changed, 3 insertions(+), 724 deletions(-) delete mode 100644 sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 0763cc926dc4..32648fa73118 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -71,7 +71,7 @@ from sgl_kernel import ( # noqa: F401 moe_align_block_size as sgl_moe_align_block_size, ) - from sgl_kernel import moe_lora_align_block_size + from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size @dataclass diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 578a1455c70f..d0b6fcf80bf6 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -227,14 +227,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "pad_sorted_token_ids) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - m.def( - "moe_lora_align_block_size(Tensor topk_ids, Tensor seg_indptr, Tensor req_to_lora, " - "int num_experts, int block_size, int max_loras, int max_num_tokens_padded, " - "int max_num_m_blocks, Tensor! sorted_token_ids, Tensor! expert_ids, " - "Tensor! num_tokens_post_pad, Tensor adapter_enabled, Tensor lora_ids, " - "Tensor? 
maybe_expert_map) -> ()"); - m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); - m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float " "moe_softcapping, Tensor? correction_bias) -> ()"); diff --git a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu b/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu deleted file mode 100644 index 63030ea18dd4..000000000000 --- a/sgl-kernel/csrc/moe/moe_lora_align_sum_kernel.cu +++ /dev/null @@ -1,662 +0,0 @@ -// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu - -#include -#include -#include - -#include -#include - -// ================================================================ -// STANDALONE UTILS REPLACEMENT -// Insert this after your standard #include statements -// ================================================================ - -#ifndef WARP_SIZE -#define WARP_SIZE 32 -#endif - -// Used in batched_moe_align_block_size -template -__host__ __device__ inline T round_to_next_multiple_of(T x, T y) { - return ((x + y - 1) / y) * y; -} - -// Minimal Dispatch Macros to avoid compiling full utils -#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Int: { \ - using scalar_t = int32_t; \ - __VA_ARGS__(); \ - break; \ - } \ - case at::ScalarType::Long: { \ - using scalar_t = int64_t; \ - __VA_ARGS__(); \ - break; \ - } \ - default: \ - TORCH_CHECK(false, #NAME " not implemented for ", toString(TYPE)); \ - } - -#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) 
\ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t = float; \ - __VA_ARGS__(); \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__(); \ - break; \ - } \ - case at::ScalarType::BFloat16: { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__(); \ - break; \ - } \ - default: \ - TORCH_CHECK(false, #NAME " not implemented for ", toString(TYPE)); \ - } -// ================================================================ - -#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) - -namespace moe { - -// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. -static constexpr int32_t num_threads = 1024; -static constexpr int32_t num_blocks = 1; - -template -__device__ void _moe_align_block_size( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, - int32_t num_experts, - int32_t padded_num_experts, - int32_t experts_per_warp, - int32_t block_size, - size_t numel, - int32_t* __restrict__ cumsum, - int32_t max_num_tokens_padded, - int32_t max_num_m_blocks, - int32_t model_offset, - int32_t inactive_expert_id, - int32_t topk_num, - int32_t* token_mask, - bool has_expert_map) { - extern __shared__ int32_t shared_counts[]; - - // Compute input buffer offsets. Typically these will all be 0, except when - // using Multi LoRA. - int sorted_token_ids_offset = max_num_tokens_padded * model_offset; - int expert_ids_offset = max_num_m_blocks * model_offset; - int cumsum_offset = (num_experts + 1) * model_offset; - - // Use separate threadblocks to fill sorted_token_ids. - // This is safe since the current kernel does not use sorted_token_ids. 
- if (blockIdx.x % 2) { - // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { - sorted_token_ids[sorted_token_ids_offset + it] = numel; - } - return; - } - - const int warp_id = threadIdx.x / WARP_SIZE; - const int my_expert_start = warp_id * experts_per_warp; - - for (int i = 0; i < experts_per_warp; ++i) { - if (my_expert_start + i < padded_num_experts) { - shared_counts[warp_id * experts_per_warp + i] = 0; - } - } - - __syncthreads(); - - const size_t tid = threadIdx.x; - const size_t stride = blockDim.x; - - for (size_t i = tid; i < numel; i += stride) { - int expert_id = topk_ids[i]; - if (expert_id >= num_experts) { - continue; - } - if (has_expert_map) { - expert_id = expert_map[expert_id]; - // filter invalid experts - if (expert_id == -1) continue; - } - int warp_idx = expert_id / experts_per_warp; - int expert_offset = expert_id % experts_per_warp; - int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; - atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], mask); - } - - __syncthreads(); - - // Compute prefix sum over token counts per expert - using BlockScan = cub::BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - - int expert_count = 0; - int expert_id = threadIdx.x; - if (expert_id < num_experts) { - int warp_idx = expert_id / experts_per_warp; - int expert_offset = expert_id % experts_per_warp; - expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; - expert_count = CEILDIV(expert_count, block_size) * block_size; - } - - int cumsum_val; - BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val); - if (expert_id <= num_experts) { - cumsum[cumsum_offset + expert_id] = cumsum_val; - } - - if (expert_id == num_experts) { - total_tokens_post_pad[model_offset] = cumsum_val; - } - - __syncthreads(); - - if (threadIdx.x < num_experts) { - for (int i = cumsum[cumsum_offset + threadIdx.x]; i < cumsum[cumsum_offset 
+ threadIdx.x + 1]; i += block_size) { - expert_ids[expert_ids_offset + i / block_size] = threadIdx.x; - } - } - - // Fill remaining expert_ids with 0 - const size_t fill_start_idx = cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; - for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { - expert_ids[expert_ids_offset + i] = inactive_expert_id; - } -} - -template -__device__ void _moe_align_block_size_small_batch_expert( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, - int32_t* __restrict__ expert_map, - int32_t num_experts, - int32_t block_size, - size_t numel, - int32_t max_num_tokens_padded, - int32_t max_num_m_blocks, - int32_t inactive_expert_id, - int32_t model_offset, - int32_t topk_num, - int32_t* token_mask, - bool has_expert_map) { - // Compute input buffer offsets. Typically these will all be 0, except when - // using Multi LoRA. - int sorted_token_ids_offset = max_num_tokens_padded * model_offset; - int expert_ids_offset = max_num_m_blocks * model_offset; - - // Use an additional group of threads to fill sorted_token_ids. - // Since the current kernel will use sorted_token_ids afterward, - // we fill sorted_token_ids within the same threadblock to make - // synchronization easier. 
- if (threadIdx.x < fill_threads) { - // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += fill_threads) { - sorted_token_ids[sorted_token_ids_offset + it] = numel; - } - // Three __syncthreads() corresponding to the other threads - __syncthreads(); - __syncthreads(); - __syncthreads(); - return; - } - - const size_t tid = threadIdx.x - fill_threads; - const size_t stride = blockDim.x - fill_threads; - - extern __shared__ int32_t shared_mem[]; - int32_t* cumsum = shared_mem; - int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[(tid + 1) * num_experts + i] = 0; - } - - for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i]; - if (has_expert_map) { - expert_id = expert_map[expert_id]; - // filter invalid expert - if (expert_id == -1) continue; - } - int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; - tokens_cnts[(tid + 1) * num_experts + expert_id] += mask; - } - - __syncthreads(); - - if (tid < num_experts) { - tokens_cnts[tid] = 0; - for (int i = 1; i <= stride; ++i) { - tokens_cnts[i * num_experts + tid] += tokens_cnts[(i - 1) * num_experts + tid]; - } - } - - __syncthreads(); - - if (tid == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * block_size; - } - total_tokens_post_pad[model_offset] = static_cast(cumsum[num_experts]); - } - - __syncthreads(); - - if (tid < num_experts) { - for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { - expert_ids[expert_ids_offset + i / block_size] = tid; - } - } - - // Fill remaining expert_ids with 0 - const size_t fill_start_idx = cumsum[num_experts] / block_size + tid; - for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) { - expert_ids[expert_ids_offset + i] = inactive_expert_id; - } - - for (size_t i = tid; i < numel; i += 
stride) { - int32_t expert_id = topk_ids[i]; - if (has_expert_map) { - expert_id = expert_map[expert_id]; - // filter invalid expert - if (expert_id == -1) continue; - } - int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; - - if (token_mask == nullptr || token_mask[i / topk_num]) { - sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; - ++tokens_cnts[tid * num_experts + expert_id]; - } - } -} - -template -__device__ void _count_and_sort_expert_tokens( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, - size_t numel, - int32_t num_experts, - int32_t max_num_tokens_padded, - int32_t* __restrict__ token_mask, - int32_t model_offset, - int32_t topk_num, - bool has_expert_map) { - const size_t tid = blockIdx.y * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.y; - - for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i]; - if (expert_id >= num_experts) { - continue; - } - - if (has_expert_map) { - expert_id = expert_map[expert_id]; - // filter invalid experts - if (expert_id == -1) continue; - } - - if (token_mask == nullptr || token_mask[i / topk_num]) { - int32_t rank_post_pad = atomicAdd(&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); - sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = i; - } - } -} - -template -__global__ void moe_lora_align_block_size_kernel( - scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ seg_indptr, - int32_t* __restrict__ req_to_lora, - int num_reqs, - int64_t block_size, - int32_t* __restrict__ expert_map, - int num_experts, - int max_loras, - size_t numel, - int max_num_tokens_padded, - int max_num_m_blocks, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, - int32_t topk_num, - int32_t* total_tokens_post_pad, - int32_t* adapter_enabled, - int32_t* __restrict__ 
cumsum, - int32_t experts_per_warp, - int32_t padded_num_experts, - int32_t* lora_ids, - int32_t* __restrict__ token_mask, - bool has_expert_map) { - int lora_idx = blockIdx.x / 2; - int lora_id = lora_ids[lora_idx]; - if (lora_id == -1 || adapter_enabled[lora_id] == 0) { - return; - } - - // Populate the token_mask based on the token-LoRA mapping - int num_tokens = numel / topk_num; - int lora_offset = lora_id * num_tokens; - - // 1. Parallel Clear (Reset mask to 0) - // All threads help clear the mask for this adapter - for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { - token_mask[lora_offset + i] = 0; - } - - // Initialize output counter - if (threadIdx.x == 0) { - total_tokens_post_pad[lora_id] = 0; - } - - __syncthreads(); - - // 2. Segment-based Fill - // Iterate over requests. If a request matches this LoRA, fill its range. - for (int r = 0; r < num_reqs; ++r) { - if (req_to_lora[r] == lora_id) { - int start = seg_indptr[r]; - int end = seg_indptr[r + 1]; - - // Parallel Fill: All threads help mark this segment as "1" - for (int i = start + threadIdx.x; i < end; i += blockDim.x) { - token_mask[lora_offset + i] = 1; - } - } - } - - __syncthreads(); - - _moe_align_block_size( - topk_ids, - sorted_token_ids, - expert_ids, - total_tokens_post_pad, - expert_map, - num_experts, - padded_num_experts, - experts_per_warp, - block_size, - numel, - cumsum, - max_num_tokens_padded, - max_num_m_blocks, - lora_id, - -1, - topk_num, - &token_mask[(lora_id * num_tokens)], - has_expert_map); -} - -template -__global__ void lora_count_and_sort_expert_tokens_kernel( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, - int32_t* __restrict__ expert_map, - size_t numel, - int32_t num_experts, - int32_t max_num_tokens_padded, - int32_t topk_num, - int32_t* token_mask, - int32_t* lora_ids, - bool has_expert_map) { - int lora_idx = blockIdx.x; - int lora_id = lora_ids[lora_idx]; - if (lora_id == -1) { 
- return; - } - - int num_tokens = numel / topk_num; - - _count_and_sort_expert_tokens( - topk_ids, - sorted_token_ids, - cumsum_buffer, - expert_map, - numel, - num_experts, - max_num_tokens_padded, - &token_mask[(lora_id * num_tokens)], - lora_id, - topk_num, - has_expert_map); -} - -template -__global__ void moe_lora_align_block_size_small_batch_expert_kernel( - scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ seg_indptr, - int32_t* __restrict__ req_to_lora, - int num_reqs, - int64_t block_size, - int32_t* __restrict__ expert_map, - int num_experts, - int max_loras, - size_t numel, - int max_num_tokens_padded, - int max_num_m_blocks, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ expert_ids, - int topk_num, - int32_t* total_tokens_post_pad, - int32_t* adapter_enabled, - int32_t* lora_ids, - int32_t* token_mask, - bool has_expert_map) { - int lora_idx = blockIdx.x; - int lora_id = lora_ids[lora_idx]; - if (lora_id == -1 || adapter_enabled[lora_id] == 0) { - return; - } - - int num_tokens = numel / topk_num; - int lora_offset = lora_id * num_tokens; - - // 1. Parallel Clear (Reset mask to 0) - // All threads help clear the mask for this adapter - for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { - token_mask[lora_offset + i] = 0; - } - - // Initialize output counter - if (threadIdx.x == 0) { - total_tokens_post_pad[lora_id] = 0; - } - - __syncthreads(); - - // 2. Segment-based Fill - // Iterate over requests. If a request matches this LoRA, fill its range. 
- for (int r = 0; r < num_reqs; ++r) { - if (req_to_lora[r] == lora_id) { - int start = seg_indptr[r]; - int end = seg_indptr[r + 1]; - - // Parallel Fill: All threads help mark this segment as "1" - for (int i = start + threadIdx.x; i < end; i += blockDim.x) { - token_mask[lora_offset + i] = 1; - } - } - } - - __syncthreads(); - - _moe_align_block_size_small_batch_expert( - topk_ids, - sorted_token_ids, - expert_ids, - total_tokens_post_pad, - expert_map, - num_experts, - block_size, - numel, - max_num_tokens_padded, - max_num_m_blocks, - -1, - lora_id, - topk_num, - &token_mask[(lora_id * num_tokens)], - has_expert_map); -} - -} // namespace moe - - -void moe_lora_align_block_size( - torch::Tensor topk_ids, - torch::Tensor seg_indptr, - torch::Tensor req_to_lora, - int64_t num_experts, - int64_t block_size, - int64_t max_loras, - int64_t max_num_tokens_padded, - int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, - torch::Tensor adapter_enabled, - torch::Tensor lora_ids, - std::optional maybe_expert_map) { - const int topk_num = topk_ids.size(1); - - TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); - - int device_max_shared_mem; - auto dev = topk_ids.get_device(); - cudaDeviceGetAttribute(&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - - // BlockScan uses 1024 threads and assigns one thread per expert. 
- TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); - - auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - torch::Tensor token_mask = torch::empty({max_loras * topk_ids.size(0)}, options_int); - bool has_expert_map = maybe_expert_map.has_value(); - torch::Tensor expert_map; - if (has_expert_map) { - expert_map = maybe_expert_map.value(); - } else { - expert_map = torch::empty({0}, options_int); - } - int num_reqs = seg_indptr.size(0) - 1; - - DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { - bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); - - if (small_batch_expert_mode) { - const int32_t num_thread = max((int32_t)num_experts, 128); - const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + (num_experts + 1) * sizeof(int32_t); - if (shared_mem > device_max_shared_mem) { - TORCH_CHECK(false, "Shared memory usage exceeds device limit."); - } - - // threadIdx.x >= fill_threads: counting experts and aligning - // threadIdx.x < fill_threads: filling sorted_token_ids - constexpr int32_t fill_threads = 256; - - dim3 blockDim(num_thread + fill_threads); - auto kernel = moe::moe_lora_align_block_size_small_batch_expert_kernel; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem); - kernel<<>>( - topk_ids.data_ptr(), - seg_indptr.data_ptr(), - req_to_lora.data_ptr(), - num_reqs, - block_size, - expert_map.data_ptr(), - num_experts, - max_loras, - topk_ids.numel(), - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), - topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), - lora_ids.data_ptr(), - token_mask.data_ptr(), - has_expert_map); - } else { - int num_thread = 1024; - dim3 blockDim(num_thread); - size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); - - size_t shared_mem_size = num_warps * WARP_SIZE * 
sizeof(int32_t); - - // cumsum buffer - torch::Tensor cumsum = torch::zeros({max_loras * (num_experts + 1)}, options_int); - - auto align_kernel = moe::moe_lora_align_block_size_kernel; - - // launch two threadblocks for each lora - // blockIdx.x % 2 == 0: counting experts and aligning - // blockIdx.x % 2 == 1: filling sorted_token_ids - align_kernel<<>>( - topk_ids.data_ptr(), - seg_indptr.data_ptr(), - req_to_lora.data_ptr(), - num_reqs, - block_size, - expert_map.data_ptr(), - num_experts, - max_loras, - topk_ids.numel(), - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), - topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), - cumsum.data_ptr(), - WARP_SIZE, - padded_num_experts, - lora_ids.data_ptr(), - token_mask.data_ptr(), - has_expert_map); - - const int block_threads = std::min(256, (int)num_thread); - const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; - - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - - dim3 gridDims(max_loras, actual_blocks); - auto sort_kernel = moe::lora_count_and_sort_expert_tokens_kernel; - - sort_kernel<<>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum.data_ptr(), - expert_map.data_ptr(), - topk_ids.numel(), - num_experts, - max_num_tokens_padded, - topk_num, - token_mask.data_ptr(), - lora_ids.data_ptr(), - has_expert_map); - } - }); -} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 2340b717e3d5..2eb0856aa2cb 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -324,22 +324,6 @@ void topk_sigmoid( bool renormalize, const c10::optional& correction_bias); -void moe_lora_align_block_size( - torch::Tensor topk_ids, - torch::Tensor seg_indptr, - torch::Tensor req_to_lora, - int64_t num_experts, - int64_t block_size, - int64_t max_loras, - int64_t max_num_tokens_padded, - int64_t 
max_num_m_blocks, - torch::Tensor sorted_token_ids, - torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, - torch::Tensor adapter_enabled, - torch::Tensor lora_ids, - std::optional maybe_expert_map); - void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor); void moe_sum(torch::Tensor& input, torch::Tensor& output); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 6289c68e4bd9..a24d3573be4a 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -99,7 +99,6 @@ kimi_k2_moe_fused_gate, moe_align_block_size, moe_fused_gate, - moe_lora_align_block_size, moe_sum, moe_sum_reduce, prepare_moe_input, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 31ee02381e89..d85e4b602751 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -25,40 +25,6 @@ def moe_align_block_size( ) -def moe_lora_align_block_size( - topk_ids, - seg_indptr, - req_to_lora, - num_experts, - block_size, - max_loras, - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids, - expert_ids, - num_tokens_post_pad, - adapter_enabled, - lora_ids, - maybe_expert_map=None, -): - torch.ops.sgl_kernel.moe_lora_align_block_size.default( - topk_ids, - seg_indptr, - req_to_lora, - num_experts, - block_size, - max_loras, - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids, - expert_ids, - num_tokens_post_pad, - adapter_enabled, - lora_ids, - maybe_expert_map, - ) - - def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/test/manual/lora/test_fused_moe_lora_kernel.py b/test/manual/lora/test_fused_moe_lora_kernel.py index 8d50ed95e203..af224968121b 100644 --- a/test/manual/lora/test_fused_moe_lora_kernel.py +++ b/test/manual/lora/test_fused_moe_lora_kernel.py @@ -7,7 +7,7 @@ # 
============================================================================== # IMPORT PREBUILT KERNEL # ============================================================================== -from sgl_kernel import moe_lora_align_block_size +from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size from sglang.srt.lora.triton_ops import fused_moe_lora from sglang.srt.utils import set_random_seed diff --git a/test/manual/lora/test_moe_lora_align_sum.py b/test/manual/lora/test_moe_lora_align_sum.py index 8c7a0100c7ca..505f2d3e6726 100644 --- a/test/manual/lora/test_moe_lora_align_sum.py +++ b/test/manual/lora/test_moe_lora_align_sum.py @@ -7,7 +7,7 @@ # --------------------------------------------------------- # IMPORT PREBUILT KERNEL # --------------------------------------------------------- -from sgl_kernel import moe_lora_align_block_size +from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size def round_up(x, base): From 0a576b323a663176b6abafbf7ccd365540ab3d78 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sat, 14 Feb 2026 09:13:44 -0500 Subject: [PATCH 106/150] move token sorting kernels to jit kernel --- .../csrc/lora/moe_lora_align_kernel.cu | 672 ++++++++++++++++++ python/sglang/jit_kernel/moe_lora_align.py | 61 ++ 2 files changed, 733 insertions(+) create mode 100644 python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu create mode 100644 python/sglang/jit_kernel/moe_lora_align.py diff --git a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu new file mode 100644 index 000000000000..caa272535e59 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu @@ -0,0 +1,672 @@ +// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu + +#include +#include +#include +#include + +#include +#include + +// ================================================================ +// STANDALONE UTILS REPLACEMENT 
+// Insert this after your standard #include statements +// ================================================================ + +#ifndef WARP_SIZE +#define WARP_SIZE 32 +#endif + +// Used in batched_moe_align_block_size +template +__host__ __device__ inline T round_to_next_multiple_of(T x, T y) { + return ((x + y - 1) / y) * y; +} + +// ================================================================ + +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +namespace moe { + +// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. +//static constexpr int32_t num_threads = 1024; not sure why these disappeared +//static constexpr int32_t num_blocks = 1; + +template +__device__ void _moe_align_block_size( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t padded_num_experts, + int32_t experts_per_warp, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, + int32_t model_offset, + int32_t inactive_expert_id, + int32_t topk_num, + int32_t* token_mask, + bool has_expert_map) { + extern __shared__ int32_t shared_counts[]; + + // Compute input buffer offsets. Typically these will all be 0, except when + // using Multi LoRA. + int sorted_token_ids_offset = max_num_tokens_padded * model_offset; + int expert_ids_offset = max_num_m_blocks * model_offset; + int cumsum_offset = (num_experts + 1) * model_offset; + + // Use separate threadblocks to fill sorted_token_ids. + // This is safe since the current kernel does not use sorted_token_ids. 
+ if (blockIdx.x % 2) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { + sorted_token_ids[sorted_token_ids_offset + it] = numel; + } + return; + } + + const int warp_id = threadIdx.x / WARP_SIZE; + const int my_expert_start = warp_id * experts_per_warp; + + for (int i = 0; i < experts_per_warp; ++i) { + if (my_expert_start + i < padded_num_experts) { + shared_counts[warp_id * experts_per_warp + i] = 0; + } + } + + __syncthreads(); + + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i]; + if (expert_id >= num_experts) { + continue; + } + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid experts + if (expert_id == -1) continue; + } + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], mask); + } + + __syncthreads(); + + // Compute prefix sum over token counts per expert + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + int expert_count = 0; + int expert_id = threadIdx.x; + if (expert_id < num_experts) { + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; + expert_count = CEILDIV(expert_count, block_size) * block_size; + } + + int cumsum_val; + BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val); + if (expert_id <= num_experts) { + cumsum[cumsum_offset + expert_id] = cumsum_val; + } + + if (expert_id == num_experts) { + total_tokens_post_pad[model_offset] = cumsum_val; + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + for (int i = cumsum[cumsum_offset + threadIdx.x]; i < cumsum[cumsum_offset 
+ threadIdx.x + 1]; i += block_size) { + expert_ids[expert_ids_offset + i / block_size] = threadIdx.x; + } + } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; + for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { + expert_ids[expert_ids_offset + i] = inactive_expert_id; + } +} + +template +__device__ void _moe_align_block_size_small_batch_expert( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, + int32_t inactive_expert_id, + int32_t model_offset, + int32_t topk_num, + int32_t* token_mask, + bool has_expert_map) { + // Compute input buffer offsets. Typically these will all be 0, except when + // using Multi LoRA. + int sorted_token_ids_offset = max_num_tokens_padded * model_offset; + int expert_ids_offset = max_num_m_blocks * model_offset; + + // Use an additional group of threads to fill sorted_token_ids. + // Since the current kernel will use sorted_token_ids afterward, + // we fill sorted_token_ids within the same threadblock to make + // synchronization easier. 
+ if (threadIdx.x < fill_threads) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += fill_threads) { + sorted_token_ids[sorted_token_ids_offset + it] = numel; + } + // Three __syncthreads() corresponding to the other threads + __syncthreads(); + __syncthreads(); + __syncthreads(); + return; + } + + const size_t tid = threadIdx.x - fill_threads; + const size_t stride = blockDim.x - fill_threads; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(tid + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid expert + if (expert_id == -1) continue; + } + int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; + tokens_cnts[(tid + 1) * num_experts + expert_id] += mask; + } + + __syncthreads(); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= stride; ++i) { + tokens_cnts[i * num_experts + tid] += tokens_cnts[(i - 1) * num_experts + tid]; + } + } + + __syncthreads(); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * block_size; + } + total_tokens_post_pad[model_offset] = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[expert_ids_offset + i / block_size] = tid; + } + } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + tid; + for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) { + expert_ids[expert_ids_offset + i] = inactive_expert_id; + } + + for (size_t i = tid; i < numel; i += 
stride) { + int32_t expert_id = topk_ids[i]; + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid expert + if (expert_id == -1) continue; + } + int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; // this is a culprit for bug + + if (token_mask == nullptr || token_mask[i / topk_num]) { + sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } + } +} + +template +__device__ void _count_and_sort_expert_tokens( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, + size_t numel, + int32_t num_experts, + int32_t max_num_tokens_padded, + int32_t* __restrict__ token_mask, + int32_t model_offset, + int32_t topk_num, + bool has_expert_map) { + const size_t tid = blockIdx.y * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.y; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (expert_id >= num_experts) { + continue; + } + + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid experts + if (expert_id == -1) continue; + } + + if (token_mask == nullptr || token_mask[i / topk_num]) { + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); + sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = i; + } + } +} + +template +__global__ void moe_lora_align_block_size_kernel( + scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ seg_indptr, + int32_t* __restrict__ req_to_lora, + int num_reqs, + int64_t block_size, + int32_t* __restrict__ expert_map, + int num_experts, + int max_loras, + size_t numel, + int max_num_tokens_padded, + int max_num_m_blocks, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t topk_num, + int32_t* total_tokens_post_pad, + int32_t* 
adapter_enabled, + int32_t* __restrict__ cumsum, + int32_t experts_per_warp, + int32_t padded_num_experts, + int32_t* lora_ids, + int32_t* __restrict__ token_mask, + bool has_expert_map) { + int lora_idx = blockIdx.x / 2; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; + } + + // Populate the token_mask based on the token-LoRA mapping + int num_tokens = numel / topk_num; + int lora_offset = lora_id * num_tokens; + + // 1. Parallel Clear (Reset mask to 0) + // All threads help clear the mask for this adapter + for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { + token_mask[lora_offset + i] = 0; + } + + // Initialize output counter + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; + } + + __syncthreads(); + + // 2. Segment-based Fill + // Iterate over requests. If a request matches this LoRA, fill its range. + for (int r = 0; r < num_reqs; ++r) { + if (req_to_lora[r] == lora_id) { + int start = seg_indptr[r]; + int end = seg_indptr[r + 1]; + + // Parallel Fill: All threads help mark this segment as "1" + for (int i = start + threadIdx.x; i < end; i += blockDim.x) { + token_mask[lora_offset + i] = 1; + } + } + } + + __syncthreads(); + + _moe_align_block_size( + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + padded_num_experts, + experts_per_warp, + block_size, + numel, + cumsum, + max_num_tokens_padded, + max_num_m_blocks, + lora_id, + num_experts, // inactive_expert_id padding + topk_num, + &token_mask[(lora_id * num_tokens)], + has_expert_map); +} + +template +__global__ void lora_count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, + size_t numel, + int32_t num_experts, + int32_t max_num_tokens_padded, + int32_t topk_num, + int32_t* token_mask, + int32_t* lora_ids, + bool has_expert_map) { + int 
lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1) { + return; + } + + int num_tokens = numel / topk_num; + + _count_and_sort_expert_tokens( + topk_ids, + sorted_token_ids, + cumsum_buffer, + expert_map, + numel, + num_experts, + max_num_tokens_padded, + &token_mask[(lora_id * num_tokens)], + lora_id, + topk_num, + has_expert_map); +} + +template +__global__ void moe_lora_align_block_size_small_batch_expert_kernel( + scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ seg_indptr, + int32_t* __restrict__ req_to_lora, + int num_reqs, + int64_t block_size, + int32_t* __restrict__ expert_map, + int num_experts, + int max_loras, + size_t numel, + int max_num_tokens_padded, + int max_num_m_blocks, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int topk_num, + int32_t* total_tokens_post_pad, + int32_t* adapter_enabled, + int32_t* lora_ids, + int32_t* token_mask, + bool has_expert_map) { + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; + } + + int num_tokens = numel / topk_num; + int lora_offset = lora_id * num_tokens; + + // 1. Parallel Clear (Reset mask to 0) + // All threads help clear the mask for this adapter + for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { + token_mask[lora_offset + i] = 0; + } + + // Initialize output counter + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; + } + + __syncthreads(); + + // 2. Segment-based Fill + // Iterate over requests. If a request matches this LoRA, fill its range. 
+ for (int r = 0; r < num_reqs; ++r) { + if (req_to_lora[r] == lora_id) { + int start = seg_indptr[r]; + int end = seg_indptr[r + 1]; + + // Parallel Fill: All threads help mark this segment as "1" + for (int i = start + threadIdx.x; i < end; i += blockDim.x) { + token_mask[lora_offset + i] = 1; + } + } + } + + __syncthreads(); + + _moe_align_block_size_small_batch_expert( + topk_ids, + sorted_token_ids, + expert_ids, + total_tokens_post_pad, + expert_map, + num_experts, + block_size, + numel, + max_num_tokens_padded, + max_num_m_blocks, + num_experts, // inactive_expert_id padding + lora_id, + topk_num, + &token_mask[(lora_id * num_tokens)], + has_expert_map); +} + +} // namespace moe + +namespace { + +#define DISPATCH_TVM_INTEGRAL_TYPES(tvm_dtype_code, tvm_dtype_bits, c_type, ...) \ + [&]() -> bool { \ + if (tvm_dtype_code == kDLInt && tvm_dtype_bits == 32) { \ + using c_type = int32_t; \ + return __VA_ARGS__(); \ + } \ + if (tvm_dtype_code == kDLInt && tvm_dtype_bits == 64) { \ + using c_type = int64_t; \ + return __VA_ARGS__(); \ + } \ + host::RuntimeCheck(false, "Unsupported data type. Only int32 and int64 are supported."); \ + return false; \ + }() + +struct MoeLoraAlignBlockSizeKernel { + static void run( + tvm::ffi::TensorView topk_ids, + tvm::ffi::TensorView seg_indptr, + tvm::ffi::TensorView req_to_lora, + int64_t num_experts, + int64_t block_size, + int64_t max_loras, + int64_t max_num_tokens_padded, + int64_t max_num_m_blocks, + tvm::ffi::TensorView sorted_token_ids, + tvm::ffi::TensorView expert_ids, + tvm::ffi::TensorView num_tokens_post_pad, + tvm::ffi::TensorView adapter_enabled, + tvm::ffi::TensorView lora_ids, + tvm::ffi::Optional maybe_expert_map, + tvm::ffi::TensorView cumsum_buffer, + tvm::ffi::TensorView token_mask + ) { + using namespace host; + + const int topk_num = topk_ids.size(1); + + RuntimeCheck(block_size > 0, "block_size should be greater than 0. 
"); + + int device_max_shared_mem; + auto device = topk_ids.device(); + int dev_id = device.device_id; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id)); + const cudaStream_t stream = LaunchKernel::resolve_device(device); + + int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + + // BlockScan uses 1024 threads and assigns one thread per expert. + RuntimeCheck(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); + + int32_t* token_mask_ptr = static_cast(token_mask.data_ptr()); + + bool has_expert_map = maybe_expert_map.has_value(); + int32_t* expert_map_ptr = nullptr; + if (has_expert_map) { + expert_map_ptr = static_cast(maybe_expert_map.value().data_ptr()); + } + int num_reqs = seg_indptr.size(0) - 1; + + auto topk_ids_dtype = topk_ids.dtype(); + + DISPATCH_TVM_INTEGRAL_TYPES(topk_ids_dtype.code, topk_ids_dtype.bits, scalar_t, [&] { + bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t num_thread = std::max((int32_t)num_experts, 128); + const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + (num_experts + 1) * sizeof(int32_t); + if (shared_mem > device_max_shared_mem) { + RuntimeCheck(false, "Shared memory usage exceeds device limit."); + } + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + + dim3 blockDim(num_thread + fill_threads); + auto kernel = moe::moe_lora_align_block_size_small_batch_expert_kernel; + RuntimeDeviceCheck(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); + + LaunchKernel( + dim3(max_loras), + blockDim, + stream, + shared_mem + )( + kernel, + static_cast(topk_ids.data_ptr()), + static_cast(seg_indptr.data_ptr()), + static_cast(req_to_lora.data_ptr()), + num_reqs, + 
block_size, + expert_map_ptr, + num_experts, + max_loras, + topk_ids.numel(), + max_num_tokens_padded, + max_num_m_blocks, + static_cast(sorted_token_ids.data_ptr()), + static_cast(expert_ids.data_ptr()), + topk_num, + static_cast(num_tokens_post_pad.data_ptr()), + static_cast(adapter_enabled.data_ptr()), + static_cast(lora_ids.data_ptr()), + token_mask_ptr, + has_expert_map + ); + + } else { + int num_thread = 1024; + dim3 blockDim(num_thread); + size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); + + size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); + + auto align_kernel = moe::moe_lora_align_block_size_kernel; + + // launch two threadblocks for each lora + // blockIdx.x % 2 == 0: counting experts and aligning + // blockIdx.x % 2 == 1: filling sorted_token_ids + LaunchKernel( + dim3(max_loras * 2), + blockDim, + stream, + shared_mem_size + )( + align_kernel, + static_cast(topk_ids.data_ptr()), + static_cast(seg_indptr.data_ptr()), + static_cast(req_to_lora.data_ptr()), + num_reqs, + block_size, + expert_map_ptr, + num_experts, + max_loras, + topk_ids.numel(), + max_num_tokens_padded, + max_num_m_blocks, + static_cast(sorted_token_ids.data_ptr()), + static_cast(expert_ids.data_ptr()), + topk_num, + static_cast(num_tokens_post_pad.data_ptr()), + static_cast(adapter_enabled.data_ptr()), + static_cast(cumsum_buffer.data_ptr()), + WARP_SIZE, + padded_num_experts, + static_cast(lora_ids.data_ptr()), + token_mask_ptr, + has_expert_map + ); + + const int block_threads = std::min(256, (int)num_thread); + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + dim3 gridDims(max_loras, actual_blocks); + auto sort_kernel = moe::lora_count_and_sort_expert_tokens_kernel; + + LaunchKernel( + gridDims, + dim3(block_threads), + stream + )( + sort_kernel, + static_cast(topk_ids.data_ptr()), + static_cast(sorted_token_ids.data_ptr()), + 
static_cast(cumsum_buffer.data_ptr()), + expert_map_ptr, + topk_ids.numel(), + num_experts, + max_num_tokens_padded, + topk_num, + token_mask_ptr, + static_cast(lora_ids.data_ptr()), + has_expert_map + ); + } + return true; + }); + + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/moe_lora_align.py b/python/sglang/jit_kernel/moe_lora_align.py new file mode 100644 index 000000000000..9aede3e1ac32 --- /dev/null +++ b/python/sglang/jit_kernel/moe_lora_align.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_moe_align_module() -> Module: + return load_jit( + "moe_lora_align_block_size", + cuda_files=["lora/moe_lora_align_kernel.cu"], + cuda_wrappers=[ + ("moe_lora_align_block_size", "MoeLoraAlignBlockSizeKernel::run"), + ], + ) + +def moe_lora_align_block_size( + topk_ids: torch.Tensor, + seg_indptr: torch.Tensor, + req_to_lora: torch.Tensor, + num_experts: int, + block_size: int, + max_loras: int, + max_num_tokens_padded: int, + max_num_m_blocks: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, + adapter_enabled: torch.Tensor, + lora_ids: torch.Tensor, + maybe_expert_map: Optional[torch.Tensor] = None, +) -> None: + module = _jit_moe_align_module() + + cumsum_buffer = torch.zeros(max_loras * (num_experts + 1), dtype=torch.int32, device=topk_ids.device) + token_mask = torch.empty((max_loras * topk_ids.shape[0],), dtype=torch.int32, device=topk_ids.device) + + module.moe_lora_align_block_size( + topk_ids, + seg_indptr, + req_to_lora, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + adapter_enabled, + lora_ids, + maybe_expert_map, + cumsum_buffer, + token_mask, + ) From 
d2e2b3580aa6a1f2aea44766eaff0020a368d7a4 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 15 Feb 2026 14:36:18 -0500 Subject: [PATCH 107/150] Fix small error --- python/sglang/srt/lora/mem_pool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 8d99a38683b6..0ff170f063d5 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -163,7 +163,7 @@ def get_lora_A_shape( input_dim = divide(input_dim, self.tp_size) if self.is_moe_module(module_name): - num_experts = self.base_model.config.num_experts + num_experts = base_model.config.num_experts return ( self.max_loras_per_batch, num_experts, @@ -212,7 +212,7 @@ def get_lora_B_shape( # Check if MoE module and return appropriate shape if self.is_moe_module(module_name): - num_experts = self.base_model.config.num_experts + num_experts = base_model.config.num_experts return (self.max_loras_per_batch, num_experts, output_dim, max_lora_dim) else: return (self.max_loras_per_batch, output_dim, max_lora_dim) From e246c10d15f2cd6160af407d21bdb2f8894c9cf7 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 15 Feb 2026 15:47:22 -0500 Subject: [PATCH 108/150] small fix --- python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu index caa272535e59..ffae56d4e86a 100644 --- a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu +++ b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu @@ -363,7 +363,7 @@ __global__ void moe_lora_align_block_size_kernel( max_num_tokens_padded, max_num_m_blocks, lora_id, - num_experts, // inactive_expert_id padding + -1, // inactive_expert_id padding topk_num, &token_mask[(lora_id * num_tokens)], has_expert_map); @@ -474,7 +474,7 @@ __global__ void 
moe_lora_align_block_size_small_batch_expert_kernel( numel, max_num_tokens_padded, max_num_m_blocks, - num_experts, // inactive_expert_id padding + -1, // inactive_expert_id padding lora_id, topk_num, &token_mask[(lora_id * num_tokens)], From 8053b5f24548342c4030e55801b8fefd3578aa23 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 15 Feb 2026 16:08:06 -0500 Subject: [PATCH 109/150] fix hf test --- .../lora/test_lora_hf_sgl_logprob_diff.py | 49 ++++++++++++++----- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index 6eed9f6723aa..9676b0caa2f5 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -461,6 +461,8 @@ def _run_comparison_test( disable_cuda_graph: bool = DISABLE_CUDA_GRAPH, lora_target_modules: Optional[List[str]] = LORA_TARGET_MODULES, tp_size: int = 1, + check_logprobs: bool = True, + output_match_threshold: Optional[float] = None, ): """ Run comparison test between SGLang and HuggingFace with LoRA. 
@@ -497,17 +499,27 @@ def _run_comparison_test( # Step 3: Compare log probabilities results, overall_stats = compare_logprobs(sglang_logprobs, hf_logprobs) - # Assert that all prompts pass the threshold - for result in results: - self.assertTrue( - result["prefill_logprob_match"], - f"Prefill logprob mismatch for prompt {result['prompt_idx']} " - f"(max_diff={result['prefill_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", - ) - self.assertTrue( - result["decode_logprob_match"], - f"Decode logprob mismatch for prompt {result['prompt_idx']} " - f"(max_diff={result['decode_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", + if check_logprobs: + # Assert that all prompts pass the threshold + for result in results: + self.assertTrue( + result["prefill_logprob_match"], + f"Prefill logprob mismatch for prompt {result['prompt_idx']} " + f"(max_diff={result['prefill_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", + ) + self.assertTrue( + result["decode_logprob_match"], + f"Decode logprob mismatch for prompt {result['prompt_idx']} " + f"(max_diff={result['decode_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", + ) + # MoE's expert layers make logprob comparisons useless as the base MoE layers' output significantly differs between sglang and hf + if output_match_threshold is not None: + outputs_match_count = sum(r["outputs_match"] for r in results) + match_rate = outputs_match_count / len(results) + self.assertGreaterEqual( + match_rate, + output_match_threshold, + f"Output string match rate {match_rate:.2%} is below threshold {output_match_threshold:.2%}", ) print_section_header("Test completed successfully!") @@ -551,7 +563,16 @@ def test_moe_lora_logprob_comparison_basic(self): model_path = "Qwen/Qwen1.5-MoE-A2.7B" lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] - prompts = DEFAULT_TEST_PROMPTS[:2] # Use first 2 default prompts for basic test + + # Load prompts from JSON file + prompts_path = os.path.join( + 
os.path.dirname(__file__), + "prompts", + "sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json", + ) + with open(prompts_path, "r") as f: + prompts = json.load(f) + prompts = prompts[:2] self._run_comparison_test( model_path=model_path, @@ -559,6 +580,8 @@ def test_moe_lora_logprob_comparison_basic(self): prompts=prompts, max_new_tokens=32, lora_backend="triton", + check_logprobs=False, + output_match_threshold=0.9, ) def test_moe_lora_logprob_comparison_full(self): @@ -584,6 +607,8 @@ def test_moe_lora_logprob_comparison_full(self): prompts=prompts, max_new_tokens=32, lora_backend="triton", + check_logprobs=False, + output_match_threshold=0.9, ) From 5959187763169fb5caf085766f8adeaf6b000105 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 15 Feb 2026 20:56:31 -0500 Subject: [PATCH 110/150] fix --- python/sglang/srt/lora/layers.py | 8 -------- .../srt/lora/triton_ops/fused_moe_lora_kernel.py | 1 - python/sglang/srt/server_args.py | 10 ++++++++++ 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 55b702562b02..2f142e60b523 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -595,14 +595,6 @@ def __init__( base_layer: FusedMoE, lora_backend: BaseLoRABackend, ): - - # initialize triton_lora moe runner for batches with lora enabled - if lora_backend.name != "triton": - raise ValueError( - "FusedMoEWithLoRA only supports 'triton' backend. " - "Please set --lora-backend triton when using LoRA on MoE models." 
- ) - # initializes FusedMoE with its own moe_runner for base path super().__init__(base_layer, lora_backend) # LoRA tensors will be set by LoRAManager diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 833a2dafdcdd..36037f1efa23 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -3,7 +3,6 @@ import torch import triton import triton.language as tl -from sgl_kernel.utils import is_arch_support_pdl from sglang.srt.distributed import ( tensor_model_parallel_all_gather, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2c9f7926819d..816d8e5055ba 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -5381,6 +5381,16 @@ def check_lora_server_args(self): self.lora_target_modules.discard("embed_tokens") self.lora_target_modules.discard("lm_head") + # TODO: find creative solution to differentiate between MoE gate_proj, up_proj, and down_proj and non-MoE gate_proj, up_proj, and down_proj so we do not have to drop all. + if self.lora_backend != "triton": + logger.warning("Current LoRA backend does not support LoRA on MoE layers; " + "dropping 'gate_proj', 'up_proj', and 'down_proj from --lora-target-modules=all.", + "To apply LoRA to these, use --lora-backend triton." + ) + self.lora_target_modules.discard("gate_proj") + self.lora_target_modules.discard("up_proj") + self.lora_target_modules.discard("down_proj") + # Ensure sufficient information is provided for LoRA initialization. 
assert self.lora_paths or ( self.max_lora_rank and self.lora_target_modules From 224982aa7e339a83ba9235edd714c282516788a1 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 15 Feb 2026 21:06:31 -0500 Subject: [PATCH 111/150] fix --- python/sglang/srt/lora/layers.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 2f142e60b523..e898060b3fb7 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -597,11 +597,6 @@ def __init__( ): # initializes FusedMoE with its own moe_runner for base path super().__init__(base_layer, lora_backend) - # LoRA tensors will be set by LoRAManager - self.gate_up_lora_a_weights = None - self.gate_up_lora_b_weights = None - self.down_lora_a_weights = None - self.down_lora_b_weights = None # initialize triton_lora moe runner for batches with lora enabled from sglang.srt.layers.moe.moe_runner.runner import MoeRunner @@ -727,7 +722,6 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): def get_lora_layer( layer: nn.Module, lora_backend: BaseLoRABackend ) -> BaseLayerWithLoRA: - supported_layer_types = { # the order matters FusedMoE: FusedMoEWithLoRA, From 9f6aeec2a615c052475a0621075d827d870bfb69 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 15 Feb 2026 21:27:02 -0500 Subject: [PATCH 112/150] remove unnecessary injection of max_lora_rank --- python/sglang/srt/lora/layers.py | 4 ++-- python/sglang/srt/lora/lora_manager.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index e898060b3fb7..1aa4a3514bac 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -642,8 +642,8 @@ def _get_lora_info(self): batch_info = self.lora_backend.batch_info lora_ranks = batch_info.lora_ranks # [num_loras] - # Compute max LoRA rank from current batch ranks - max_lora_rank = self.lora_backend.max_lora_rank + 
# max_lora_rank + max_lora_rank = self.gate_up_lora_a_weights.shape[2] # Create adapter_enabled tensor for the current batch # Only enable LoRA adapters that are actually used in this batch diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 5101162a1ff8..052667facb70 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -386,9 +386,6 @@ def init_state( target_modules=target_modules, ) - # Inject max_lora_rank into backend - self.lora_backend.max_lora_rank = self.max_lora_rank - self.init_lora_modules() self.init_memory_pool() self.update_lora_info() From 0b440e3005dd6e3ee379c6f5e35167a7be7c29a5 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 15 Feb 2026 21:31:15 -0500 Subject: [PATCH 113/150] Revert "remove unnecessary injection of max_lora_rank" This reverts commit 9f6aeec2a615c052475a0621075d827d870bfb69. --- python/sglang/srt/lora/layers.py | 4 ++-- python/sglang/srt/lora/lora_manager.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 1aa4a3514bac..e898060b3fb7 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -642,8 +642,8 @@ def _get_lora_info(self): batch_info = self.lora_backend.batch_info lora_ranks = batch_info.lora_ranks # [num_loras] - # max_lora_rank - max_lora_rank = self.gate_up_lora_a_weights.shape[2] + # Compute max LoRA rank from current batch ranks + max_lora_rank = self.lora_backend.max_lora_rank # Create adapter_enabled tensor for the current batch # Only enable LoRA adapters that are actually used in this batch diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 052667facb70..5101162a1ff8 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -386,6 +386,9 @@ def init_state( target_modules=target_modules, ) + # Inject 
max_lora_rank into backend + self.lora_backend.max_lora_rank = self.max_lora_rank + self.init_lora_modules() self.init_memory_pool() self.update_lora_info() From 68ea9c943580395d2afe5cb886cc19980d7b106c Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 15 Feb 2026 22:14:19 -0500 Subject: [PATCH 114/150] fix max lora ranks calc --- python/sglang/srt/lora/layers.py | 3 +-- python/sglang/srt/lora/lora_manager.py | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index e898060b3fb7..7a7dd5ab3eb5 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -642,8 +642,7 @@ def _get_lora_info(self): batch_info = self.lora_backend.batch_info lora_ranks = batch_info.lora_ranks # [num_loras] - # Compute max LoRA rank from current batch ranks - max_lora_rank = self.lora_backend.max_lora_rank + max_lora_rank = self.down_lora_a_weights.shape[2] # Create adapter_enabled tensor for the current batch # Only enable LoRA adapters that are actually used in this batch diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 5101162a1ff8..a0db8c5336a8 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -386,8 +386,6 @@ def init_state( target_modules=target_modules, ) - # Inject max_lora_rank into backend - self.lora_backend.max_lora_rank = self.max_lora_rank self.init_lora_modules() self.init_memory_pool() From bf5448ac6ef334d2ebf648d5cde7b0d53ea00fc4 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 16 Feb 2026 10:01:59 -0500 Subject: [PATCH 115/150] lint and fix dropping lora modules issue --- .../csrc/lora/moe_lora_align_kernel.cu | 185 ++++++++---------- python/sglang/jit_kernel/moe_lora_align.py | 9 +- python/sglang/srt/lora/lora_manager.py | 9 +- python/sglang/srt/lora/lora_moe_runners.py | 11 +- .../lora/triton_ops/fused_moe_lora_kernel.py | 8 +- 
python/sglang/srt/server_args.py | 11 +- .../manual/lora/test_fused_moe_lora_kernel.py | 1 - 7 files changed, 107 insertions(+), 127 deletions(-) diff --git a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu index ffae56d4e86a..01686bd05964 100644 --- a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu +++ b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu @@ -2,11 +2,13 @@ #include #include + #include + +#include #include #include -#include // ================================================================ // STANDALONE UTILS REPLACEMENT @@ -30,8 +32,8 @@ __host__ __device__ inline T round_to_next_multiple_of(T x, T y) { namespace moe { // Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. -//static constexpr int32_t num_threads = 1024; not sure why these disappeared -//static constexpr int32_t num_blocks = 1; +// static constexpr int32_t num_threads = 1024; not sure why these disappeared +// static constexpr int32_t num_blocks = 1; template __device__ void _moe_align_block_size( @@ -241,7 +243,8 @@ __device__ void _moe_align_block_size_small_batch_expert( // filter invalid expert if (expert_id == -1) continue; } - int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; // this is a culprit for bug + int32_t rank_post_pad = + tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; // this is a culprit for bug if (token_mask == nullptr || token_mask[i / topk_num]) { sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; @@ -363,7 +366,7 @@ __global__ void moe_lora_align_block_size_kernel( max_num_tokens_padded, max_num_m_blocks, lora_id, - -1, // inactive_expert_id padding + -1, // inactive_expert_id padding topk_num, &token_mask[(lora_id * num_tokens)], has_expert_map); @@ -474,7 +477,7 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( numel, max_num_tokens_padded, 
max_num_m_blocks, - -1, // inactive_expert_id padding + -1, // inactive_expert_id padding lora_id, topk_num, &token_mask[(lora_id * num_tokens)], @@ -485,23 +488,23 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel( namespace { -#define DISPATCH_TVM_INTEGRAL_TYPES(tvm_dtype_code, tvm_dtype_bits, c_type, ...) \ - [&]() -> bool { \ - if (tvm_dtype_code == kDLInt && tvm_dtype_bits == 32) { \ - using c_type = int32_t; \ - return __VA_ARGS__(); \ - } \ - if (tvm_dtype_code == kDLInt && tvm_dtype_bits == 64) { \ - using c_type = int64_t; \ - return __VA_ARGS__(); \ - } \ - host::RuntimeCheck(false, "Unsupported data type. Only int32 and int64 are supported."); \ - return false; \ +#define DISPATCH_TVM_INTEGRAL_TYPES(tvm_dtype_code, tvm_dtype_bits, c_type, ...) \ + [&]() -> bool { \ + if (tvm_dtype_code == kDLInt && tvm_dtype_bits == 32) { \ + using c_type = int32_t; \ + return __VA_ARGS__(); \ + } \ + if (tvm_dtype_code == kDLInt && tvm_dtype_bits == 64) { \ + using c_type = int64_t; \ + return __VA_ARGS__(); \ + } \ + host::RuntimeCheck(false, "Unsupported data type. 
Only int32 and int64 are supported."); \ + return false; \ }() struct MoeLoraAlignBlockSizeKernel { - static void run( - tvm::ffi::TensorView topk_ids, + static void + run(tvm::ffi::TensorView topk_ids, tvm::ffi::TensorView seg_indptr, tvm::ffi::TensorView req_to_lora, int64_t num_experts, @@ -516,8 +519,7 @@ struct MoeLoraAlignBlockSizeKernel { tvm::ffi::TensorView lora_ids, tvm::ffi::Optional maybe_expert_map, tvm::ffi::TensorView cumsum_buffer, - tvm::ffi::TensorView token_mask - ) { + tvm::ffi::TensorView token_mask) { using namespace host; const int topk_num = topk_ids.size(1); @@ -551,7 +553,8 @@ struct MoeLoraAlignBlockSizeKernel { if (small_batch_expert_mode) { const int32_t num_thread = std::max((int32_t)num_experts, 128); - const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + (num_experts + 1) * sizeof(int32_t); + const int32_t shared_mem = + (num_thread + 1) * num_experts * sizeof(int32_t) + (num_experts + 1) * sizeof(int32_t); if (shared_mem > device_max_shared_mem) { RuntimeCheck(false, "Shared memory usage exceeds device limit."); } @@ -564,33 +567,27 @@ struct MoeLoraAlignBlockSizeKernel { auto kernel = moe::moe_lora_align_block_size_small_batch_expert_kernel; RuntimeDeviceCheck(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem)); - LaunchKernel( - dim3(max_loras), - blockDim, - stream, - shared_mem - )( - kernel, - static_cast(topk_ids.data_ptr()), - static_cast(seg_indptr.data_ptr()), - static_cast(req_to_lora.data_ptr()), - num_reqs, - block_size, - expert_map_ptr, - num_experts, - max_loras, - topk_ids.numel(), - max_num_tokens_padded, - max_num_m_blocks, - static_cast(sorted_token_ids.data_ptr()), - static_cast(expert_ids.data_ptr()), - topk_num, - static_cast(num_tokens_post_pad.data_ptr()), - static_cast(adapter_enabled.data_ptr()), - static_cast(lora_ids.data_ptr()), - token_mask_ptr, - has_expert_map - ); + LaunchKernel(dim3(max_loras), blockDim, stream, shared_mem)( + kernel, + 
static_cast(topk_ids.data_ptr()), + static_cast(seg_indptr.data_ptr()), + static_cast(req_to_lora.data_ptr()), + num_reqs, + block_size, + expert_map_ptr, + num_experts, + max_loras, + topk_ids.numel(), + max_num_tokens_padded, + max_num_m_blocks, + static_cast(sorted_token_ids.data_ptr()), + static_cast(expert_ids.data_ptr()), + topk_num, + static_cast(num_tokens_post_pad.data_ptr()), + static_cast(adapter_enabled.data_ptr()), + static_cast(lora_ids.data_ptr()), + token_mask_ptr, + has_expert_map); } else { int num_thread = 1024; @@ -604,36 +601,30 @@ struct MoeLoraAlignBlockSizeKernel { // launch two threadblocks for each lora // blockIdx.x % 2 == 0: counting experts and aligning // blockIdx.x % 2 == 1: filling sorted_token_ids - LaunchKernel( - dim3(max_loras * 2), - blockDim, - stream, - shared_mem_size - )( - align_kernel, - static_cast(topk_ids.data_ptr()), - static_cast(seg_indptr.data_ptr()), - static_cast(req_to_lora.data_ptr()), - num_reqs, - block_size, - expert_map_ptr, - num_experts, - max_loras, - topk_ids.numel(), - max_num_tokens_padded, - max_num_m_blocks, - static_cast(sorted_token_ids.data_ptr()), - static_cast(expert_ids.data_ptr()), - topk_num, - static_cast(num_tokens_post_pad.data_ptr()), - static_cast(adapter_enabled.data_ptr()), - static_cast(cumsum_buffer.data_ptr()), - WARP_SIZE, - padded_num_experts, - static_cast(lora_ids.data_ptr()), - token_mask_ptr, - has_expert_map - ); + LaunchKernel(dim3(max_loras * 2), blockDim, stream, shared_mem_size)( + align_kernel, + static_cast(topk_ids.data_ptr()), + static_cast(seg_indptr.data_ptr()), + static_cast(req_to_lora.data_ptr()), + num_reqs, + block_size, + expert_map_ptr, + num_experts, + max_loras, + topk_ids.numel(), + max_num_tokens_padded, + max_num_m_blocks, + static_cast(sorted_token_ids.data_ptr()), + static_cast(expert_ids.data_ptr()), + topk_num, + static_cast(num_tokens_post_pad.data_ptr()), + static_cast(adapter_enabled.data_ptr()), + static_cast(cumsum_buffer.data_ptr()), + 
WARP_SIZE, + padded_num_experts, + static_cast(lora_ids.data_ptr()), + token_mask_ptr, + has_expert_map); const int block_threads = std::min(256, (int)num_thread); const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; @@ -644,29 +635,23 @@ struct MoeLoraAlignBlockSizeKernel { dim3 gridDims(max_loras, actual_blocks); auto sort_kernel = moe::lora_count_and_sort_expert_tokens_kernel; - LaunchKernel( - gridDims, - dim3(block_threads), - stream - )( - sort_kernel, - static_cast(topk_ids.data_ptr()), - static_cast(sorted_token_ids.data_ptr()), - static_cast(cumsum_buffer.data_ptr()), - expert_map_ptr, - topk_ids.numel(), - num_experts, - max_num_tokens_padded, - topk_num, - token_mask_ptr, - static_cast(lora_ids.data_ptr()), - has_expert_map - ); + LaunchKernel(gridDims, dim3(block_threads), stream)( + sort_kernel, + static_cast(topk_ids.data_ptr()), + static_cast(sorted_token_ids.data_ptr()), + static_cast(cumsum_buffer.data_ptr()), + expert_map_ptr, + topk_ids.numel(), + num_experts, + max_num_tokens_padded, + topk_num, + token_mask_ptr, + static_cast(lora_ids.data_ptr()), + has_expert_map); } return true; }); - } }; -} // namespace +} // namespace diff --git a/python/sglang/jit_kernel/moe_lora_align.py b/python/sglang/jit_kernel/moe_lora_align.py index 9aede3e1ac32..d53978ec297e 100644 --- a/python/sglang/jit_kernel/moe_lora_align.py +++ b/python/sglang/jit_kernel/moe_lora_align.py @@ -20,6 +20,7 @@ def _jit_moe_align_module() -> Module: ], ) + def moe_lora_align_block_size( topk_ids: torch.Tensor, seg_indptr: torch.Tensor, @@ -38,8 +39,12 @@ def moe_lora_align_block_size( ) -> None: module = _jit_moe_align_module() - cumsum_buffer = torch.zeros(max_loras * (num_experts + 1), dtype=torch.int32, device=topk_ids.device) - token_mask = torch.empty((max_loras * topk_ids.shape[0],), dtype=torch.int32, device=topk_ids.device) + cumsum_buffer = torch.zeros( + max_loras * (num_experts + 1), dtype=torch.int32, device=topk_ids.device + ) + token_mask = 
torch.empty( + (max_loras * topk_ids.shape[0],), dtype=torch.int32, device=topk_ids.device + ) module.moe_lora_align_block_size( topk_ids, diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index a0db8c5336a8..bd43adf45b0c 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -386,7 +386,6 @@ def init_state( target_modules=target_modules, ) - self.init_lora_modules() self.init_memory_pool() self.update_lora_info() @@ -621,6 +620,14 @@ def init_lora_modules(self): if isinstance(module, FusedMoE) and all( x in self.target_modules for x in ["gate_up_proj", "down_proj"] ): + + if self.lora_backend.name != "triton": + logger.warning( + "Current LoRA backend does not support LoRA on MoE layers; " + "skipping MoE layer." + ) + continue + layer_id = get_layer_id(module_name) self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 32648fa73118..e4984a5122f1 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -52,15 +52,7 @@ from sgl_kernel import gelu_and_mul, silu_and_mul if _is_hip: - if _use_aiter: - try: - from aiter import moe_sum - except ImportError: - raise ImportError( - "aiter is required when SGLANG_USE_AITER is set to True" - ) - else: - from vllm import _custom_ops as vllm_ops # moe_sum + from vllm import _custom_ops as vllm_ops # moe_sum elif _is_cpu and _is_cpu_amx_available: pass elif _is_xpu: @@ -71,6 +63,7 @@ from sgl_kernel import ( # noqa: F401 moe_align_block_size as sgl_moe_align_block_size, ) + from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 36037f1efa23..75c933bb7197 100644 --- 
a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -8,8 +8,7 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) - -from sglang.srt.utils.common import is_sm90_supported, is_blackwell_supported +from sglang.srt.utils.common import is_blackwell_supported, is_sm90_supported # Import SGLang's standard PDL support detection @@ -99,7 +98,6 @@ def _fused_moe_lora_kernel( lora_idx = tl.program_id(axis=2) lora_id = tl.load(lora_ids + lora_idx) - if lora_id == -1: # Early exit for the no-lora case. return @@ -243,7 +241,7 @@ def _fused_moe_lora_shrink( ) -> None: w1_lora_a_stacked = lora_a_stacked[0] - use_gdc = (is_sm90_supported() or is_blackwell_supported()) + use_gdc = is_sm90_supported() or is_blackwell_supported() shrink_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, @@ -348,7 +346,7 @@ def _fused_moe_lora_expand( -1, a_intermediate_cache1.shape[3] ) - use_gdc = (is_sm90_supported() or is_blackwell_supported()) + use_gdc = is_sm90_supported() or is_blackwell_supported() expand_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 816d8e5055ba..67b4e70b3c2e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -5381,15 +5381,8 @@ def check_lora_server_args(self): self.lora_target_modules.discard("embed_tokens") self.lora_target_modules.discard("lm_head") - # TODO: find creative solution to differentiate between MoE gate_proj, up_proj, and down_proj and non-MoE gate_proj, up_proj, and down_proj so we do not have to drop all. - if self.lora_backend != "triton": - logger.warning("Current LoRA backend does not support LoRA on MoE layers; " - "dropping 'gate_proj', 'up_proj', and 'down_proj from --lora-target-modules=all.", - "To apply LoRA to these, use --lora-backend triton." 
- ) - self.lora_target_modules.discard("gate_proj") - self.lora_target_modules.discard("up_proj") - self.lora_target_modules.discard("down_proj") + # TODO: find creative solution to differentiate between MoE gate_proj, up_proj, and down_proj and non-MoE gate_proj, up_proj, and down_proj here so we do not have to do + # it later in LoRA Manager # Ensure sufficient information is provided for LoRA initialization. assert self.lora_paths or ( diff --git a/test/manual/lora/test_fused_moe_lora_kernel.py b/test/manual/lora/test_fused_moe_lora_kernel.py index af224968121b..67fc785ac018 100644 --- a/test/manual/lora/test_fused_moe_lora_kernel.py +++ b/test/manual/lora/test_fused_moe_lora_kernel.py @@ -8,7 +8,6 @@ # IMPORT PREBUILT KERNEL # ============================================================================== from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size - from sglang.srt.lora.triton_ops import fused_moe_lora from sglang.srt.utils import set_random_seed From de6ff7bf580d9483fce5b7610f0d3ad1d8c2e0e1 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 16 Feb 2026 10:23:10 -0500 Subject: [PATCH 116/150] add prompts back --- ...glang_lora_moe_test_qwen1.5-MoE-A2.7B.json | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 test/registered/lora/prompts/sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json diff --git a/test/registered/lora/prompts/sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json b/test/registered/lora/prompts/sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json new file mode 100644 index 000000000000..cbfcb2403395 --- /dev/null +++ b/test/registered/lora/prompts/sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json @@ -0,0 +1,62 @@ +[ + "Reverse the following security code: 0HRUP0A ->", + "Write a one-sentence story about a wild tree at the cafe.", + "Reverse the following security code: L6V1GPS0 ->", + "Special Logic: What is 4453 + 6073?", + "Special Logic: What is 7918 + 560?", + "Reverse the following security code: ZQ582JQ ->", + 
"Write a one-sentence story about a slow cat in a simulation.", + "Write a one-sentence story about a dusty dragon in a castle.", + "Reverse the following security code: FBDJ4T ->", + "Write a one-sentence story about a dusty dragon on the moon.", + "Write a one-sentence story about a sad tree in a dungeon.", + "Reverse the following security code: JTKC8ZV ->", + "Reverse the following security code: WRNS9RA6 ->", + "Write a one-sentence story about a calm ninja in the ocean.", + "Write a one-sentence story about a glowing fairy in Paris.", + "Special Logic: What is 6200 + 7656?", + "Reverse the following security code: KRONNFW ->", + "Reverse the following security code: 439VS65 ->", + "Special Logic: What is 5826 + 6255?", + "Special Logic: What is 1406 + 6905?", + "Special Logic: What is 8751 + 3491?", + "Write a one-sentence story about a shiny robot in the jungle.", + "Reverse the following security code: SNRXGWX ->", + "Reverse the following security code: MV77BRP ->", + "Write a one-sentence story about a golden toaster on a cloud.", + "Special Logic: What is 5286 + 5653?", + "Write a one-sentence story about a slow detective in a spaceship.", + "Special Logic: What is 386 + 9948?", + "Special Logic: What is 7781 + 3605?", + "Reverse the following security code: T3N4AKNH ->", + "Reverse the following security code: QTAB86S ->", + "Special Logic: What is 7800 + 9453?", + "Write a one-sentence story about a brave cowboy in a time machine.", + "Special Logic: What is 2025 + 9721?", + "Write a one-sentence story about a dusty samurai under the bed.", + "Special Logic: What is 7626 + 7437?", + "Reverse the following security code: 2ITR7MU1 ->", + "Reverse the following security code: X9BBSM ->", + "Write a one-sentence story about a silent queen in a simulation.", + "Reverse the following security code: S2WA4S ->", + "Special Logic: What is 9965 + 5211?", + "Special Logic: What is 6975 + 7239?", + "Write a one-sentence story about a brave detective on Mars.", + 
"Write a one-sentence story about a angry detective in the jungle.", + "Reverse the following security code: 6M0PLGR ->", + "Write a one-sentence story about a sad tree at school.", + "Special Logic: What is 7309 + 4487?", + "Special Logic: What is 3585 + 2451?", + "Write a one-sentence story about a broken samurai under the bed.", + "Reverse the following security code: BSJDXAG ->", + "Special Logic: What is 7556 + 6234?", + "Reverse the following security code: LDD0P48 ->", + "Reverse the following security code: IWCTORL ->", + "Special Logic: What is 8547 + 9401?", + "Special Logic: What is 4250 + 4220?", + "Reverse the following security code: A6VTFB ->", + "Write a one-sentence story about a heavy pilot in a time machine.", + "Write a one-sentence story about a heavy ghost under the bed.", + "Write a one-sentence story about a dark zombie in a spaceship.", + "Special Logic: What is 6542 + 1558?" +] From 0bd02ce132e256e7d98df7a45737cec80bf39398 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sat, 21 Feb 2026 23:12:44 +0000 Subject: [PATCH 117/150] modify some ci tests --- .../csrc/lora/moe_lora_align_kernel.cu | 2 +- .../lora/triton_ops/fused_moe_lora_kernel.py | 2 +- ...glang_lora_moe_test_qwen1.5-MoE-A2.7B.json | 62 ------------------- .../lora/test_fused_moe_lora_kernel.py | 5 +- .../lora/test_lora_hf_sgl_logprob_diff.py | 48 +++++++------- .../lora/test_lora_moe_runner.py | 3 + .../lora/test_moe_lora_align_sum.py | 5 +- 7 files changed, 38 insertions(+), 89 deletions(-) delete mode 100644 test/registered/lora/prompts/sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json rename test/{manual => registered}/lora/test_fused_moe_lora_kernel.py (97%) rename test/{manual => registered}/lora/test_lora_moe_runner.py (99%) rename test/{manual => registered}/lora/test_moe_lora_align_sum.py (95%) diff --git a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu index 01686bd05964..33c9ac61e79a 100644 --- 
a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu +++ b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu @@ -1,4 +1,4 @@ -// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu +// Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu, will optimize in future refactor #include #include diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 75c933bb7197..627ae72ca924 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +# Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/vllm/lora/ops/triton_ops/fused_moe_lora_op.py, will optimize in future refactor import torch import triton diff --git a/test/registered/lora/prompts/sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json b/test/registered/lora/prompts/sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json deleted file mode 100644 index cbfcb2403395..000000000000 --- a/test/registered/lora/prompts/sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json +++ /dev/null @@ -1,62 +0,0 @@ -[ - "Reverse the following security code: 0HRUP0A ->", - "Write a one-sentence story about a wild tree at the cafe.", - "Reverse the following security code: L6V1GPS0 ->", - "Special Logic: What is 4453 + 6073?", - "Special Logic: What is 7918 + 560?", - "Reverse the following security code: ZQ582JQ ->", - "Write a one-sentence story about a slow cat in a simulation.", - "Write a one-sentence story about a dusty dragon in a castle.", - "Reverse the following security code: FBDJ4T ->", - "Write a one-sentence story about a dusty dragon on the moon.", - "Write a one-sentence story about a sad tree in a dungeon.", - "Reverse the following security 
code: JTKC8ZV ->", - "Reverse the following security code: WRNS9RA6 ->", - "Write a one-sentence story about a calm ninja in the ocean.", - "Write a one-sentence story about a glowing fairy in Paris.", - "Special Logic: What is 6200 + 7656?", - "Reverse the following security code: KRONNFW ->", - "Reverse the following security code: 439VS65 ->", - "Special Logic: What is 5826 + 6255?", - "Special Logic: What is 1406 + 6905?", - "Special Logic: What is 8751 + 3491?", - "Write a one-sentence story about a shiny robot in the jungle.", - "Reverse the following security code: SNRXGWX ->", - "Reverse the following security code: MV77BRP ->", - "Write a one-sentence story about a golden toaster on a cloud.", - "Special Logic: What is 5286 + 5653?", - "Write a one-sentence story about a slow detective in a spaceship.", - "Special Logic: What is 386 + 9948?", - "Special Logic: What is 7781 + 3605?", - "Reverse the following security code: T3N4AKNH ->", - "Reverse the following security code: QTAB86S ->", - "Special Logic: What is 7800 + 9453?", - "Write a one-sentence story about a brave cowboy in a time machine.", - "Special Logic: What is 2025 + 9721?", - "Write a one-sentence story about a dusty samurai under the bed.", - "Special Logic: What is 7626 + 7437?", - "Reverse the following security code: 2ITR7MU1 ->", - "Reverse the following security code: X9BBSM ->", - "Write a one-sentence story about a silent queen in a simulation.", - "Reverse the following security code: S2WA4S ->", - "Special Logic: What is 9965 + 5211?", - "Special Logic: What is 6975 + 7239?", - "Write a one-sentence story about a brave detective on Mars.", - "Write a one-sentence story about a angry detective in the jungle.", - "Reverse the following security code: 6M0PLGR ->", - "Write a one-sentence story about a sad tree at school.", - "Special Logic: What is 7309 + 4487?", - "Special Logic: What is 3585 + 2451?", - "Write a one-sentence story about a broken samurai under the bed.", - "Reverse 
the following security code: BSJDXAG ->", - "Special Logic: What is 7556 + 6234?", - "Reverse the following security code: LDD0P48 ->", - "Reverse the following security code: IWCTORL ->", - "Special Logic: What is 8547 + 9401?", - "Special Logic: What is 4250 + 4220?", - "Reverse the following security code: A6VTFB ->", - "Write a one-sentence story about a heavy pilot in a time machine.", - "Write a one-sentence story about a heavy ghost under the bed.", - "Write a one-sentence story about a dark zombie in a spaceship.", - "Special Logic: What is 6542 + 1558?" -] diff --git a/test/manual/lora/test_fused_moe_lora_kernel.py b/test/registered/lora/test_fused_moe_lora_kernel.py similarity index 97% rename from test/manual/lora/test_fused_moe_lora_kernel.py rename to test/registered/lora/test_fused_moe_lora_kernel.py index 67fc785ac018..94110467de98 100644 --- a/test/manual/lora/test_fused_moe_lora_kernel.py +++ b/test/registered/lora/test_fused_moe_lora_kernel.py @@ -1,4 +1,4 @@ -# adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_fused_moe_lora_kernel.py +# Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_fused_moe_lora_kernel.py, will optimize in future refactor import random import pytest @@ -10,9 +10,12 @@ from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size from sglang.srt.lora.triton_ops import fused_moe_lora from sglang.srt.utils import set_random_seed +from sglang.test.ci.ci_register import register_cuda_ci # ============================================================================== +register_cuda_ci(est_time=120, suite="stage-b-test-large-1-gpu") + def round_up(x, base): return ((x + base - 1) // base) * base diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index 9676b0caa2f5..ff8ace5ea3f6 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ 
b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -27,9 +27,7 @@ python -m unittest test_lora_hf_sgl_logprob_diff """ -import json import multiprocessing as mp -import os import unittest from typing import Any, Dict, List, Optional, Tuple @@ -63,6 +61,29 @@ "What are the main components of a computer?", ] +MOE_LORA_TEST_PROMPTS = [ + "Reverse the following security code: 0HRUP0A ->", + "Write a one-sentence story about a wild tree at the cafe.", + "Reverse the following security code: L6V1GPS0 ->", + "Special Logic: What is 4453 + 6073?", + "Special Logic: What is 7918 + 560?", + "Write a one-sentence story about a slow cat in a simulation.", + "Write a one-sentence story about a dusty dragon in a castle.", + "Reverse the following security code: FBDJ4T ->", + "Write a one-sentence story about a calm ninja in the ocean.", + "Write a one-sentence story about a glowing fairy in Paris.", + "Special Logic: What is 6200 + 7656?", + "Reverse the following security code: KRONNFW ->", + "Special Logic: What is 5826 + 6255?", + "Write a one-sentence story about a shiny robot in the jungle.", + "Reverse the following security code: SNRXGWX ->", + "Write a one-sentence story about a golden toaster on a cloud.", + "Special Logic: What is 5286 + 5653?", + "Write a one-sentence story about a brave cowboy in a time machine.", + "Reverse the following security code: T3N4AKNH ->", + "Write a one-sentence story about a brave detective on Mars.", +] + # Formatting constants DIVIDER_WIDTH = 80 SECTION_CHAR = "=" @@ -560,19 +581,9 @@ def test_moe_lora_logprob_comparison_basic(self): """ Test comparing HF and SGLang MoE LoRA logprobs with basic prompts. 
""" - model_path = "Qwen/Qwen1.5-MoE-A2.7B" lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] - - # Load prompts from JSON file - prompts_path = os.path.join( - os.path.dirname(__file__), - "prompts", - "sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json", - ) - with open(prompts_path, "r") as f: - prompts = json.load(f) - prompts = prompts[:2] + prompts = MOE_LORA_TEST_PROMPTS[:2] self._run_comparison_test( model_path=model_path, @@ -588,18 +599,9 @@ def test_moe_lora_logprob_comparison_full(self): """ Full test comparing HF and SGLang MoE LoRA logprobs with all default prompts. """ - model_path = "Qwen/Qwen1.5-MoE-A2.7B" lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] - - # Load prompts from JSON file - prompts_path = os.path.join( - os.path.dirname(__file__), - "prompts", - "sglang_lora_moe_test_qwen1.5-MoE-A2.7B.json", - ) - with open(prompts_path, "r") as f: - prompts = json.load(f) + prompts = MOE_LORA_TEST_PROMPTS self._run_comparison_test( model_path=model_path, diff --git a/test/manual/lora/test_lora_moe_runner.py b/test/registered/lora/test_lora_moe_runner.py similarity index 99% rename from test/manual/lora/test_lora_moe_runner.py rename to test/registered/lora/test_lora_moe_runner.py index d3971f75403f..00d52568047c 100644 --- a/test/manual/lora/test_lora_moe_runner.py +++ b/test/registered/lora/test_lora_moe_runner.py @@ -29,6 +29,9 @@ from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.lora.lora_moe_runners import LoRAInfo from sglang.srt.utils import set_random_seed +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=80, suite="stage-b-test-large-1-gpu") def generate_request_data( diff --git a/test/manual/lora/test_moe_lora_align_sum.py b/test/registered/lora/test_moe_lora_align_sum.py similarity index 95% rename from test/manual/lora/test_moe_lora_align_sum.py rename to test/registered/lora/test_moe_lora_align_sum.py index 505f2d3e6726..26372e5ac063 100644 --- 
a/test/manual/lora/test_moe_lora_align_sum.py +++ b/test/registered/lora/test_moe_lora_align_sum.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_moe_lora_align_sum.py +# Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_moe_lora_align_sum.py, will optimize in future refactor import random import pytest @@ -8,6 +8,9 @@ # IMPORT PREBUILT KERNEL # --------------------------------------------------------- from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=80, suite="stage-b-test-large-1-gpu") def round_up(x, base): From 1461e8cf010879384d6119510fb7b767c912beeb Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sun, 22 Feb 2026 00:13:18 +0000 Subject: [PATCH 118/150] fix some tests --- python/sglang/srt/lora/lora_moe_runners.py | 23 +++++++------------ .../lora/triton_ops/fused_moe_lora_kernel.py | 11 ++++----- .../lora/test_fused_moe_lora_kernel.py | 4 ++-- .../lora/test_moe_lora_align_sum.py | 2 +- 4 files changed, 16 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index e4984a5122f1..2960dc133e02 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -222,9 +222,8 @@ def run( # ============================== # Perform LoRA alignment for both gate up and gate down operations # Define shrink_config for LoRA alignment - shrink_config = { - "BLOCK_SIZE_M": 64 - } # Default block size, can be made configurable + # TODO: Add autotuning for block sizes across different GPU architectures and problem sizes + shrink_config = {"BLOCK_SIZE_M": 64} # Prepare inputs for the kernel block_size_m = shrink_config["BLOCK_SIZE_M"] @@ -255,7 +254,6 @@ def run( (max_loras,), dtype=torch.int32, device=device ) - # Get token-to-LoRA mapping from lora_info lora_ids = 
torch.arange(max_loras, dtype=torch.int32, device=device) moe_lora_align_block_size( @@ -290,6 +288,7 @@ def run( sorted_token_ids_reshaped=sorted_token_ids_reshaped, expert_ids_reshaped=expert_ids_reshaped, num_tokens_post_padded_lora=num_tokens_post_padded_lora, + lora_ids=lora_ids, ) # ============================================================ @@ -390,6 +389,7 @@ def run( sorted_token_ids_reshaped=sorted_token_ids_reshaped, expert_ids_reshaped=expert_ids_reshaped, num_tokens_post_padded_lora=num_tokens_post_padded_lora, + lora_ids=lora_ids, ) # ============================================================ @@ -450,6 +450,7 @@ def _add_lora_gate_up_delta( sorted_token_ids_reshaped: torch.Tensor, expert_ids_reshaped: torch.Tensor, num_tokens_post_padded_lora: torch.Tensor, + lora_ids: torch.Tensor, ) -> None: """ Add LoRA gate_up delta to intermediate_cache in-place. @@ -471,12 +472,6 @@ def _add_lora_gate_up_delta( lora_a_stacked = [lora_info.gate_up_lora_a_weights] lora_b_stacked = [lora_info.gate_up_lora_b_weights] - max_loras = len(lora_info.lora_ranks) - - lora_ids = torch.arange( - max_loras, dtype=torch.int32, device=hidden_states.device - ) - fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=hidden_states, @@ -490,6 +485,7 @@ def _add_lora_gate_up_delta( top_k_num=top_k, lora_ids=lora_ids, adapter_enabled=lora_info.adapter_enabled, + # TODO: Replace hardcoded block sizes with autotuned configs shrink_block_size_m=64, shrink_block_size_n=64, shrink_block_size_k=64, @@ -515,6 +511,7 @@ def _add_lora_down_delta( sorted_token_ids_reshaped: torch.Tensor, expert_ids_reshaped: torch.Tensor, num_tokens_post_padded_lora: torch.Tensor, + lora_ids: torch.Tensor, ) -> None: """ Add LoRA down delta to intermediate_cache in-place. 
@@ -533,14 +530,9 @@ def _add_lora_down_delta( actual_max_lora_rank = lora_info.max_lora_rank - max_loras = len(lora_info.lora_ranks) - lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] - device = intermediate_cache.device - lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) - fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, @@ -554,6 +546,7 @@ def _add_lora_down_delta( top_k_num=top_k, lora_ids=lora_ids, adapter_enabled=lora_info.adapter_enabled, + # TODO: Replace hardcoded block sizes with autotuned configs shrink_block_size_m=64, shrink_block_size_n=64, shrink_block_size_k=64, diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 627ae72ca924..2f75e459dbab 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -446,11 +446,6 @@ def _fused_moe_lora( == qcurr_hidden_states.dim() == 2 ) - if ( - sorted_token_ids.shape[0] != expert_ids.shape[0] - or sorted_token_ids.shape[0] != num_tokens_post_padded.shape[0] - ): - x = 1 assert ( sorted_token_ids.shape[0] == expert_ids.shape[0] @@ -587,6 +582,8 @@ def _fused_moe_lora_fake( expand_num_stages: int, expand_split_k: int, mul_routed_weight: bool = False, + fully_sharded: bool = False, + offset: int = 0, ) -> None: return @@ -625,6 +622,7 @@ def _fused_moe_lora_shrink_fake( def _fused_moe_lora_expand_fake( output: torch.Tensor, a_intermediate_cache1: torch.Tensor, + b_intermediate_cache1: torch.Tensor, lora_b_stacked: list[torch.Tensor], topk_weights: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -651,6 +649,7 @@ def _fused_moe_lora_expand_fake( num_stages: int, split_k: int, mul_routed_weight: bool = False, + offset: int = 0, ) -> None: return @@ -676,7 +675,7 @@ def _fused_moe_lora_expand_fake( direct_register_custom_op( 
op_name="fused_moe_lora_expand", op_func=_fused_moe_lora_expand, - mutates_args=["output"], + mutates_args=["output", "b_intermediate_cache1"], fake_impl=_fused_moe_lora_expand_fake, ) diff --git a/test/registered/lora/test_fused_moe_lora_kernel.py b/test/registered/lora/test_fused_moe_lora_kernel.py index 94110467de98..ef054557042b 100644 --- a/test/registered/lora/test_fused_moe_lora_kernel.py +++ b/test/registered/lora/test_fused_moe_lora_kernel.py @@ -155,7 +155,7 @@ def use_fused_moe_lora_kernel( ) num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32, device=device) adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=device) - lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device=device) + lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) # call kernel moe_lora_align_block_size( @@ -350,7 +350,7 @@ def test_fused_moe_lora_kernel( top_k_num, ) - torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2) if __name__ == "__main__": diff --git a/test/registered/lora/test_moe_lora_align_sum.py b/test/registered/lora/test_moe_lora_align_sum.py index 26372e5ac063..9eafc62fde94 100644 --- a/test/registered/lora/test_moe_lora_align_sum.py +++ b/test/registered/lora/test_moe_lora_align_sum.py @@ -92,7 +92,7 @@ def test_moe_lora_align_block_size( ) num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda") adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda") - lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda") + lora_ids = torch.arange(max_loras, dtype=torch.int32, device="cuda") # UPDATED: Call kernel with new signature moe_lora_align_block_size( From 0922c4a72b005d667f2d7ccc9218fc439f70a8c7 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sun, 22 Feb 2026 03:12:33 +0000 Subject: [PATCH 119/150] pre-commit --- 
python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu index 33c9ac61e79a..8e9edea0250f 100644 --- a/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu +++ b/python/sglang/jit_kernel/csrc/lora/moe_lora_align_kernel.cu @@ -1,4 +1,5 @@ -// Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu, will optimize in future refactor +// Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu, will +// optimize in future refactor #include #include From 13cbf1ce1676460692085bf09ea8b9dcd3ca9939 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Thu, 26 Feb 2026 21:08:14 -0500 Subject: [PATCH 120/150] rename moe lora align block size kernel test file --- ...st_moe_lora_align_sum.py => test_moe_lora_align_block_size.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/registered/lora/{test_moe_lora_align_sum.py => test_moe_lora_align_block_size.py} (100%) diff --git a/test/registered/lora/test_moe_lora_align_sum.py b/test/registered/lora/test_moe_lora_align_block_size.py similarity index 100% rename from test/registered/lora/test_moe_lora_align_sum.py rename to test/registered/lora/test_moe_lora_align_block_size.py From 6872b3a1900629ac8c76e7605c301af488aa862d Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Thu, 26 Feb 2026 21:48:08 -0500 Subject: [PATCH 121/150] add vllm baseline comparison test --- .../test_lora_moe_vllm_sgl_logprob_diff.py | 383 ++++++++++++++++++ test/run_suite_nightly.py | 1 + 2 files changed, 384 insertions(+) create mode 100644 test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py new file mode 
100644 index 000000000000..9c372413855e --- /dev/null +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -0,0 +1,383 @@ +import os +import unittest +import torch +import gc +from sglang.test.runners import SRTRunner +import multiprocessing as mp + +from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci + +register_cuda_ci( + est_time=25, + suite="stage-b-test-small-1-gpu", +) +register_amd_ci( + est_time=50, + suite="stage-b-test-small-1-gpu-amd", +) + + +# Format: [{"text": "result string", "lps": [0.1, 0.2, ...]}, ...] +VLLM_CACHED_RESULTS = [ + { + "text": " A0PURH0", + "lps": [ + -3.3378546504536644e-06, + -1.585470999998506e-05, + -7.152555099310121e-07, + -4.1960789531003684e-05, + -3.862306402879767e-05, + -3.2305197237292305e-05 + ] + }, + { + "text": " The wild tree jumped at the cafe and found a", + "lps": [ + -2.3841830625315197e-06, + -1.5735502529423684e-05, + -0.0001658063702052459, + -0.000666277133859694, + -5.328513361746445e-05, + -0.0001012035645544529, + -0.000302030734019354, + -6.6756979322235566e-06, + 0.0, + -9.298280929215252e-06 + ] + }, + { + "text": " 0SPG1V6L", + "lps": [ + -3.814689989667386e-06, + -7.199982064776123e-05, + -6.4490144723095e-05, + -5.2569914259947836e-05, + -7.033100700937212e-05, + -5.245195097813848e-06, + -1.6927575416048057e-05, + -2.5629668016335927e-05, + -5.23315102327615e-05 + ] + }, + { + "text": " Tango", + "lps": [ + -5.960462772236497e-07, + -9.536738616588991e-07 + ] + }, + { + "text": " Tensor", + "lps": [ + -0.0002002515539061278, + -5.960462772236497e-07 + ] + }, + { + "text": " The slow cat coded in a simulation and found a", + "lps": [ + 0.0, + -4.672895011026412e-05, + -3.802703940891661e-05, + -3.1709168979432434e-05, + 0.0, + -2.145764938177308e-06, + -4.565611743601039e-05, + 0.0, + 0.0, + -2.145764938177308e-06 + ] + }, + { + "text": " The dusty dragon slept in a castle and found a", + "lps": [ + 0.0, + -3.290122185717337e-05, + -1.1444026313256472e-05, + 
-6.544376083184034e-05, + -8.344646857949556e-07, + -2.276871418871451e-05, + -2.1576648578047752e-05, + -5.960462772236497e-07, + 0.0, + -2.50339189733495e-06 + ] + }, + { + "text": " T4JDBF", + "lps": [ + -5.960462772236497e-07, + -3.4450891689630225e-05, + -1.1324817933200393e-05, + -1.6689160474925302e-05, + -0.00020013237372040749, + -3.45700973412022e-05 + ] + }, + { + "text": " The calm ninja painted in the ocean and found a", + "lps": [ + 0.0, + -3.731181277544238e-05, + -6.198863957251888e-06, + -3.576272320060525e-06, + -3.576278118089249e-07, + -3.814689989667386e-06, + -1.549708758830093e-05, + -1.1920928244535389e-07, + 0.0, + -4.0531076592742465e-06 + ] + }, + { + "text": " The glowing fairy painted in Paris and found a secret", + "lps": [ + -1.1920928244535389e-07, + -2.8132995794294402e-05, + -2.50339189733495e-06, + -4.446407547220588e-05, + -3.576278118089249e-07, + -8.201262971851975e-05, + -3.576278118089249e-07, + 0.0, + -4.0531076592742465e-06, + -4.291525328881107e-06 + ] + }, + { + "text": " Tensor", + "lps": [ + -0.00014399446081370115, + -2.622600959512056e-06 + ] + }, + { + "text": " WFNNORK", + "lps": [ + -0.0003231241717003286, + -3.71926071238704e-05, + -0.00011252723925281316, + -5.447716102935374e-05 + ] + }, + { + "text": " Whiskey", + "lps": [ + -5.531158240046352e-05, + -1.5497195136049413e-06, + -1.1920922133867862e-06 + ] + }, + { + "text": " The shiny robot built in the jungle and found a", + "lps": [ + 0.0, + -2.622600959512056e-06, + -5.018585216021165e-05, + -0.0015173362335190177, + 0.0, + -6.198863957251888e-06, + -0.00036769305006600916, + -1.1920928244535389e-07, + 0.0, + -3.099436753473128e-06 + ] + }, + { + "text": " XWGXRNS", + "lps": [ + -2.5629668016335927e-05, + -4.0531076592742465e-06, + -0.0001616347290109843, + -5.018585216021165e-05, + -0.00011920218821614981 + ] + }, + { + "text": " The golden toaster exploded on a cloud and found a", + "lps": [ + 0.0, + -8.630380034446716e-05, + 0.0, + 
-2.4676019165781327e-05, + -1.0728830375228426e-06, + -1.5497195136049413e-06, + -6.794906312279636e-06, + -4.887569048150908e-06, + 0.0, + -3.3378546504536644e-06 + ] + }, + { + "text": " Nebula", + "lps": [ + -4.410734163684538e-06, + -7.986990567587782e-06, + -1.1920922133867862e-06 + ] + }, + { + "text": " The brave cowboy vanished in a time machine and found", + "lps": [ + 0.0, + -8.475421054754406e-05, + -0.00011932138295378536, + -0.00016735584358684719, + -2.3841855067985307e-07, + -2.312633478140924e-05, + -6.5205356804654e-05, + -0.00014423283573705703, + -1.4305104514278355e-06, + 0.0 + ] + }, + { + "text": " HNKA4N3T", + "lps": [ + -2.50339189733495e-06, + -1.1920928244535389e-07, + -5.006777428206988e-06, + -7.390948667307384e-06, + -0.00014327930693980306, + -2.3841855067985307e-07, + -0.00011062010162277147, + -1.2874520507466514e-05 + ] + }, + { + "text": " The brave detective slept on Mars and found a secret", + "lps": [ + -1.7881377516459906e-06, + -1.9788545614574105e-05, + -1.883488948806189e-05, + -1.4781842764932662e-05, + -3.576278118089249e-07, + -1.2755313036905136e-05, + -5.960462772236497e-07, + 0.0, + -4.0531076592742465e-06, + -1.5497195136049413e-06 + ] + } +] +# --------------------------------- + + +# Hardcoded reference stats from successful run. Corresponds to prompts below. 
+REFERENCE_STATS = { + 0: {"max": 0.00674386, "mean": 0.00160156}, + 1: {"max": 0.12912774, "mean": 0.01756858}, + 2: {"max": 0.04976820, "mean": 0.01073238}, + 3: {"max": 0.00061362, "mean": 0.00030800}, + 4: {"max": 0.00163008, "mean": 0.00081546}, + 5: {"max": 0.00165166, "mean": 0.00037106}, + 6: {"max": 0.00076501, "mean": 0.00023368}, + 7: {"max": 0.00730696, "mean": 0.00239233}, + 8: {"max": 0.00102265, "mean": 0.00022402}, + 9: {"max": 0.00296390, "mean": 0.00055469}, + 10: {"max": 0.03297430, "mean": 0.01652637}, + 11: {"max": 0.00680350, "mean": 0.00202494}, + 12: {"max": 0.00106475, "mean": 0.00041551}, + 13: {"max": 0.02126923, "mean": 0.00256910}, + 14: {"max": 0.01500390, "mean": 0.00412217}, + 15: {"max": 0.00208589, "mean": 0.00053662}, + 16: {"max": 0.00019310, "mean": 0.00006929}, + 17: {"max": 0.00075716, "mean": 0.00019498}, + 18: {"max": 0.00670147, "mean": 0.00225596}, + 19: {"max": 0.00084245, "mean": 0.00036147}, +} + +MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" +LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" +PROMPTS = [ + "Reverse the following security code: 0HRUP0A ->", + "Write a one-sentence story about a wild tree at the cafe.", + "Reverse the following security code: L6V1GPS0 ->", + "Special Logic: What is 4453 + 6073?", + "Special Logic: What is 7918 + 560?", + "Write a one-sentence story about a slow cat in a simulation.", + "Write a one-sentence story about a dusty dragon in a castle.", + "Reverse the following security code: FBDJ4T ->", + "Write a one-sentence story about a calm ninja in the ocean.", + "Write a one-sentence story about a glowing fairy in Paris.", + "Special Logic: What is 6200 + 7656?", + "Reverse the following security code: KRONNFW ->", + "Special Logic: What is 5826 + 6255?", + "Write a one-sentence story about a shiny robot in the jungle.", + "Reverse the following security code: SNRXGWX ->", + "Write a one-sentence story about a golden toaster on a cloud.", + "Special Logic: What is 5286 + 
5653?", + "Write a one-sentence story about a brave cowboy in a time machine.", + "Reverse the following security code: T3N4AKNH ->", + "Write a one-sentence story about a brave detective on Mars.", +] + + +class TestMoELoraRegression(unittest.TestCase): + + def test_sglang_moe_parity_strict(self): + + with SRTRunner( + model_path=MODEL_PATH, + torch_dtype=torch.bfloat16, + model_type="generation", + lora_paths=[LORA_PATH], + lora_backend="triton", + max_loras_per_batch=1, + tp_size=1, + trust_remote_code=True, + disable_radix_cache=True, + ) as srt_runner: + + srt_outputs = srt_runner.forward( + PROMPTS, + max_new_tokens=10, + lora_paths=[LORA_PATH] * len(PROMPTS), + ) + + print("\n" + "="*140) + print(f"{'ID':<4} | {'Max Diff':<12} | {'Mean Diff':<12} | {'Status':<8} | {'Prompt'}") + print("-" * 140) + + for i, prompt in enumerate(PROMPTS): + v_data = VLLM_CACHED_RESULTS[i] + v_lps = v_data["lps"] + v_text = v_data["text"].strip() + + s_lps_raw = srt_outputs.top_output_logprobs[i] + s_lps = [float(token[0]) if isinstance(token, list) else float(token) for token in s_lps_raw] + s_text = srt_outputs.output_strs[i].strip() + + # Calculate actual stats + min_len = min(len(v_lps), len(s_lps)) + diffs = [abs(v_lps[t] - s_lps[t]) for t in range(min_len)] + + actual_max = max(diffs) if diffs else 0.0 + actual_mean = sum(diffs)/len(diffs) if diffs else 0.0 + + ref = REFERENCE_STATS[i] + # Epsilon to allow room for different, but correct, implementations + eps = 1e-4 + + # Assertions + self.assertEqual(v_text, s_text, f"String mismatch on prompt {i}") + self.assertLessEqual(actual_max, ref["max"] + eps, f"Max LogProb Diff exceeded on prompt {i}") + self.assertLessEqual(actual_mean, ref["mean"] + eps, f"Mean LogProb Diff exceeded on prompt {i}") + + print(f"{i:<4} | {actual_max:<12.6f} | {actual_mean:<12.6f} | {'✅ PASS':<8} | {prompt}") + + print("="*140) + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + 
unittest.main(warnings="ignore", verbosity=2) + finally: + # Final cleanup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() \ No newline at end of file diff --git a/test/run_suite_nightly.py b/test/run_suite_nightly.py index 6e6c701b0e6c..aa37ddd32652 100644 --- a/test/run_suite_nightly.py +++ b/test/run_suite_nightly.py @@ -15,6 +15,7 @@ TestFile("test_lora_openai_api.py", 30), TestFile("test_lora_openai_compatible.py", 150), TestFile("test_lora_hf_sgl_logprob_diff.py", 300), + TestFile("test_lora_moe_vllm_sgl_logprob_diff.py", 50), TestFile("test_batch_invariant_ops.py", 10), TestFile("test_cpp_radix_cache.py", 60), TestFile("test_deepseek_v3_deterministic.py", 240), From 60711780e53c1158b879029766419d1ab8ec580b Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Thu, 26 Feb 2026 21:52:19 -0500 Subject: [PATCH 122/150] add docstring --- .../test_lora_moe_vllm_sgl_logprob_diff.py | 546 +++++++++--------- 1 file changed, 277 insertions(+), 269 deletions(-) diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index 9c372413855e..c36dfca79cf2 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -1,11 +1,23 @@ -import os +""" +Regression test for MoE LoRA parity between SGLang and vLLM. + +This test compares SGLang's logprobs and output strings against a hardcoded +baseline (VLLM_CACHED_RESULTS) generated using vLLM. It enforces strict +numerical accuracy by asserting that the maximum and mean logprob +divergences do not exceed the reference thresholds (REFERENCE_STATS). 
+ +Usage: + python -m unittest test_lora_moe_vllm_sgl_logprob_diff.py + +""" + +import multiprocessing as mp import unittest + import torch -import gc -from sglang.test.runners import SRTRunner -import multiprocessing as mp from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.runners import SRTRunner register_cuda_ci( est_time=25, @@ -19,262 +31,244 @@ # Format: [{"text": "result string", "lps": [0.1, 0.2, ...]}, ...] VLLM_CACHED_RESULTS = [ - { - "text": " A0PURH0", - "lps": [ - -3.3378546504536644e-06, - -1.585470999998506e-05, - -7.152555099310121e-07, - -4.1960789531003684e-05, - -3.862306402879767e-05, - -3.2305197237292305e-05 - ] - }, - { - "text": " The wild tree jumped at the cafe and found a", - "lps": [ - -2.3841830625315197e-06, - -1.5735502529423684e-05, - -0.0001658063702052459, - -0.000666277133859694, - -5.328513361746445e-05, - -0.0001012035645544529, - -0.000302030734019354, - -6.6756979322235566e-06, - 0.0, - -9.298280929215252e-06 - ] - }, - { - "text": " 0SPG1V6L", - "lps": [ - -3.814689989667386e-06, - -7.199982064776123e-05, - -6.4490144723095e-05, - -5.2569914259947836e-05, - -7.033100700937212e-05, - -5.245195097813848e-06, - -1.6927575416048057e-05, - -2.5629668016335927e-05, - -5.23315102327615e-05 - ] - }, - { - "text": " Tango", - "lps": [ - -5.960462772236497e-07, - -9.536738616588991e-07 - ] - }, - { - "text": " Tensor", - "lps": [ - -0.0002002515539061278, - -5.960462772236497e-07 - ] - }, - { - "text": " The slow cat coded in a simulation and found a", - "lps": [ - 0.0, - -4.672895011026412e-05, - -3.802703940891661e-05, - -3.1709168979432434e-05, - 0.0, - -2.145764938177308e-06, - -4.565611743601039e-05, - 0.0, - 0.0, - -2.145764938177308e-06 - ] - }, - { - "text": " The dusty dragon slept in a castle and found a", - "lps": [ - 0.0, - -3.290122185717337e-05, - -1.1444026313256472e-05, - -6.544376083184034e-05, - -8.344646857949556e-07, - -2.276871418871451e-05, - -2.1576648578047752e-05, - 
-5.960462772236497e-07, - 0.0, - -2.50339189733495e-06 - ] - }, - { - "text": " T4JDBF", - "lps": [ - -5.960462772236497e-07, - -3.4450891689630225e-05, - -1.1324817933200393e-05, - -1.6689160474925302e-05, - -0.00020013237372040749, - -3.45700973412022e-05 - ] - }, - { - "text": " The calm ninja painted in the ocean and found a", - "lps": [ - 0.0, - -3.731181277544238e-05, - -6.198863957251888e-06, - -3.576272320060525e-06, - -3.576278118089249e-07, - -3.814689989667386e-06, - -1.549708758830093e-05, - -1.1920928244535389e-07, - 0.0, - -4.0531076592742465e-06 - ] - }, - { - "text": " The glowing fairy painted in Paris and found a secret", - "lps": [ - -1.1920928244535389e-07, - -2.8132995794294402e-05, - -2.50339189733495e-06, - -4.446407547220588e-05, - -3.576278118089249e-07, - -8.201262971851975e-05, - -3.576278118089249e-07, - 0.0, - -4.0531076592742465e-06, - -4.291525328881107e-06 - ] - }, - { - "text": " Tensor", - "lps": [ - -0.00014399446081370115, - -2.622600959512056e-06 - ] - }, - { - "text": " WFNNORK", - "lps": [ - -0.0003231241717003286, - -3.71926071238704e-05, - -0.00011252723925281316, - -5.447716102935374e-05 - ] - }, - { - "text": " Whiskey", - "lps": [ - -5.531158240046352e-05, - -1.5497195136049413e-06, - -1.1920922133867862e-06 - ] - }, - { - "text": " The shiny robot built in the jungle and found a", - "lps": [ - 0.0, - -2.622600959512056e-06, - -5.018585216021165e-05, - -0.0015173362335190177, - 0.0, - -6.198863957251888e-06, - -0.00036769305006600916, - -1.1920928244535389e-07, - 0.0, - -3.099436753473128e-06 - ] - }, - { - "text": " XWGXRNS", - "lps": [ - -2.5629668016335927e-05, - -4.0531076592742465e-06, - -0.0001616347290109843, - -5.018585216021165e-05, - -0.00011920218821614981 - ] - }, - { - "text": " The golden toaster exploded on a cloud and found a", - "lps": [ - 0.0, - -8.630380034446716e-05, - 0.0, - -2.4676019165781327e-05, - -1.0728830375228426e-06, - -1.5497195136049413e-06, - -6.794906312279636e-06, - 
-4.887569048150908e-06, - 0.0, - -3.3378546504536644e-06 - ] - }, - { - "text": " Nebula", - "lps": [ - -4.410734163684538e-06, - -7.986990567587782e-06, - -1.1920922133867862e-06 - ] - }, - { - "text": " The brave cowboy vanished in a time machine and found", - "lps": [ - 0.0, - -8.475421054754406e-05, - -0.00011932138295378536, - -0.00016735584358684719, - -2.3841855067985307e-07, - -2.312633478140924e-05, - -6.5205356804654e-05, - -0.00014423283573705703, - -1.4305104514278355e-06, - 0.0 - ] - }, - { - "text": " HNKA4N3T", - "lps": [ - -2.50339189733495e-06, - -1.1920928244535389e-07, - -5.006777428206988e-06, - -7.390948667307384e-06, - -0.00014327930693980306, - -2.3841855067985307e-07, - -0.00011062010162277147, - -1.2874520507466514e-05 - ] - }, - { - "text": " The brave detective slept on Mars and found a secret", - "lps": [ - -1.7881377516459906e-06, - -1.9788545614574105e-05, - -1.883488948806189e-05, - -1.4781842764932662e-05, - -3.576278118089249e-07, - -1.2755313036905136e-05, - -5.960462772236497e-07, - 0.0, - -4.0531076592742465e-06, - -1.5497195136049413e-06 - ] - } + { + "text": " A0PURH0", + "lps": [ + -3.3378546504536644e-06, + -1.585470999998506e-05, + -7.152555099310121e-07, + -4.1960789531003684e-05, + -3.862306402879767e-05, + -3.2305197237292305e-05, + ], + }, + { + "text": " The wild tree jumped at the cafe and found a", + "lps": [ + -2.3841830625315197e-06, + -1.5735502529423684e-05, + -0.0001658063702052459, + -0.000666277133859694, + -5.328513361746445e-05, + -0.0001012035645544529, + -0.000302030734019354, + -6.6756979322235566e-06, + 0.0, + -9.298280929215252e-06, + ], + }, + { + "text": " 0SPG1V6L", + "lps": [ + -3.814689989667386e-06, + -7.199982064776123e-05, + -6.4490144723095e-05, + -5.2569914259947836e-05, + -7.033100700937212e-05, + -5.245195097813848e-06, + -1.6927575416048057e-05, + -2.5629668016335927e-05, + -5.23315102327615e-05, + ], + }, + {"text": " Tango", "lps": [-5.960462772236497e-07, -9.536738616588991e-07]}, + 
{"text": " Tensor", "lps": [-0.0002002515539061278, -5.960462772236497e-07]}, + { + "text": " The slow cat coded in a simulation and found a", + "lps": [ + 0.0, + -4.672895011026412e-05, + -3.802703940891661e-05, + -3.1709168979432434e-05, + 0.0, + -2.145764938177308e-06, + -4.565611743601039e-05, + 0.0, + 0.0, + -2.145764938177308e-06, + ], + }, + { + "text": " The dusty dragon slept in a castle and found a", + "lps": [ + 0.0, + -3.290122185717337e-05, + -1.1444026313256472e-05, + -6.544376083184034e-05, + -8.344646857949556e-07, + -2.276871418871451e-05, + -2.1576648578047752e-05, + -5.960462772236497e-07, + 0.0, + -2.50339189733495e-06, + ], + }, + { + "text": " T4JDBF", + "lps": [ + -5.960462772236497e-07, + -3.4450891689630225e-05, + -1.1324817933200393e-05, + -1.6689160474925302e-05, + -0.00020013237372040749, + -3.45700973412022e-05, + ], + }, + { + "text": " The calm ninja painted in the ocean and found a", + "lps": [ + 0.0, + -3.731181277544238e-05, + -6.198863957251888e-06, + -3.576272320060525e-06, + -3.576278118089249e-07, + -3.814689989667386e-06, + -1.549708758830093e-05, + -1.1920928244535389e-07, + 0.0, + -4.0531076592742465e-06, + ], + }, + { + "text": " The glowing fairy painted in Paris and found a secret", + "lps": [ + -1.1920928244535389e-07, + -2.8132995794294402e-05, + -2.50339189733495e-06, + -4.446407547220588e-05, + -3.576278118089249e-07, + -8.201262971851975e-05, + -3.576278118089249e-07, + 0.0, + -4.0531076592742465e-06, + -4.291525328881107e-06, + ], + }, + {"text": " Tensor", "lps": [-0.00014399446081370115, -2.622600959512056e-06]}, + { + "text": " WFNNORK", + "lps": [ + -0.0003231241717003286, + -3.71926071238704e-05, + -0.00011252723925281316, + -5.447716102935374e-05, + ], + }, + { + "text": " Whiskey", + "lps": [ + -5.531158240046352e-05, + -1.5497195136049413e-06, + -1.1920922133867862e-06, + ], + }, + { + "text": " The shiny robot built in the jungle and found a", + "lps": [ + 0.0, + -2.622600959512056e-06, + 
-5.018585216021165e-05, + -0.0015173362335190177, + 0.0, + -6.198863957251888e-06, + -0.00036769305006600916, + -1.1920928244535389e-07, + 0.0, + -3.099436753473128e-06, + ], + }, + { + "text": " XWGXRNS", + "lps": [ + -2.5629668016335927e-05, + -4.0531076592742465e-06, + -0.0001616347290109843, + -5.018585216021165e-05, + -0.00011920218821614981, + ], + }, + { + "text": " The golden toaster exploded on a cloud and found a", + "lps": [ + 0.0, + -8.630380034446716e-05, + 0.0, + -2.4676019165781327e-05, + -1.0728830375228426e-06, + -1.5497195136049413e-06, + -6.794906312279636e-06, + -4.887569048150908e-06, + 0.0, + -3.3378546504536644e-06, + ], + }, + { + "text": " Nebula", + "lps": [ + -4.410734163684538e-06, + -7.986990567587782e-06, + -1.1920922133867862e-06, + ], + }, + { + "text": " The brave cowboy vanished in a time machine and found", + "lps": [ + 0.0, + -8.475421054754406e-05, + -0.00011932138295378536, + -0.00016735584358684719, + -2.3841855067985307e-07, + -2.312633478140924e-05, + -6.5205356804654e-05, + -0.00014423283573705703, + -1.4305104514278355e-06, + 0.0, + ], + }, + { + "text": " HNKA4N3T", + "lps": [ + -2.50339189733495e-06, + -1.1920928244535389e-07, + -5.006777428206988e-06, + -7.390948667307384e-06, + -0.00014327930693980306, + -2.3841855067985307e-07, + -0.00011062010162277147, + -1.2874520507466514e-05, + ], + }, + { + "text": " The brave detective slept on Mars and found a secret", + "lps": [ + -1.7881377516459906e-06, + -1.9788545614574105e-05, + -1.883488948806189e-05, + -1.4781842764932662e-05, + -3.576278118089249e-07, + -1.2755313036905136e-05, + -5.960462772236497e-07, + 0.0, + -4.0531076592742465e-06, + -1.5497195136049413e-06, + ], + }, ] # --------------------------------- # Hardcoded reference stats from successful run. Corresponds to prompts below. 
REFERENCE_STATS = { - 0: {"max": 0.00674386, "mean": 0.00160156}, - 1: {"max": 0.12912774, "mean": 0.01756858}, - 2: {"max": 0.04976820, "mean": 0.01073238}, - 3: {"max": 0.00061362, "mean": 0.00030800}, - 4: {"max": 0.00163008, "mean": 0.00081546}, - 5: {"max": 0.00165166, "mean": 0.00037106}, - 6: {"max": 0.00076501, "mean": 0.00023368}, - 7: {"max": 0.00730696, "mean": 0.00239233}, - 8: {"max": 0.00102265, "mean": 0.00022402}, - 9: {"max": 0.00296390, "mean": 0.00055469}, + 0: {"max": 0.00674386, "mean": 0.00160156}, + 1: {"max": 0.12912774, "mean": 0.01756858}, + 2: {"max": 0.04976820, "mean": 0.01073238}, + 3: {"max": 0.00061362, "mean": 0.00030800}, + 4: {"max": 0.00163008, "mean": 0.00081546}, + 5: {"max": 0.00165166, "mean": 0.00037106}, + 6: {"max": 0.00076501, "mean": 0.00023368}, + 7: {"max": 0.00730696, "mean": 0.00239233}, + 8: {"max": 0.00102265, "mean": 0.00022402}, + 9: {"max": 0.00296390, "mean": 0.00055469}, 10: {"max": 0.03297430, "mean": 0.01652637}, 11: {"max": 0.00680350, "mean": 0.00202494}, 12: {"max": 0.00106475, "mean": 0.00041551}, @@ -316,7 +310,7 @@ class TestMoELoraRegression(unittest.TestCase): def test_sglang_moe_parity_strict(self): - + with SRTRunner( model_path=MODEL_PATH, torch_dtype=torch.bfloat16, @@ -328,45 +322,59 @@ def test_sglang_moe_parity_strict(self): trust_remote_code=True, disable_radix_cache=True, ) as srt_runner: - + srt_outputs = srt_runner.forward( PROMPTS, max_new_tokens=10, lora_paths=[LORA_PATH] * len(PROMPTS), ) - print("\n" + "="*140) - print(f"{'ID':<4} | {'Max Diff':<12} | {'Mean Diff':<12} | {'Status':<8} | {'Prompt'}") + print("\n" + "=" * 140) + print( + f"{'ID':<4} | {'Max Diff':<12} | {'Mean Diff':<12} | {'Status':<8} | {'Prompt'}" + ) print("-" * 140) for i, prompt in enumerate(PROMPTS): v_data = VLLM_CACHED_RESULTS[i] v_lps = v_data["lps"] v_text = v_data["text"].strip() - + s_lps_raw = srt_outputs.top_output_logprobs[i] - s_lps = [float(token[0]) if isinstance(token, list) else float(token) for 
token in s_lps_raw] + s_lps = [ + float(token[0]) if isinstance(token, list) else float(token) + for token in s_lps_raw + ] s_text = srt_outputs.output_strs[i].strip() # Calculate actual stats min_len = min(len(v_lps), len(s_lps)) diffs = [abs(v_lps[t] - s_lps[t]) for t in range(min_len)] - + actual_max = max(diffs) if diffs else 0.0 - actual_mean = sum(diffs)/len(diffs) if diffs else 0.0 - + actual_mean = sum(diffs) / len(diffs) if diffs else 0.0 + ref = REFERENCE_STATS[i] # Epsilon to allow room for different, but correct, implementations - eps = 1e-4 + eps = 1e-4 # Assertions self.assertEqual(v_text, s_text, f"String mismatch on prompt {i}") - self.assertLessEqual(actual_max, ref["max"] + eps, f"Max LogProb Diff exceeded on prompt {i}") - self.assertLessEqual(actual_mean, ref["mean"] + eps, f"Mean LogProb Diff exceeded on prompt {i}") + self.assertLessEqual( + actual_max, ref["max"] + eps, f"Max LogProb Diff exceeded on prompt {i}" + ) + self.assertLessEqual( + actual_mean, + ref["mean"] + eps, + f"Mean LogProb Diff exceeded on prompt {i}", + ) + + print( + f"{i:<4} | {actual_max:<12.6f} | {actual_mean:<12.6f} | {'✅ PASS':<8} | {prompt}" + ) - print(f"{i:<4} | {actual_max:<12.6f} | {actual_mean:<12.6f} | {'✅ PASS':<8} | {prompt}") + print("=" * 140) - print("="*140) if __name__ == "__main__": try: @@ -380,4 +388,4 @@ def test_sglang_moe_parity_strict(self): # Final cleanup if torch.cuda.is_available(): torch.cuda.empty_cache() - torch.cuda.synchronize() \ No newline at end of file + torch.cuda.synchronize() From d5f0e739d882af6c3dbf98a171426694696f4267 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Fri, 27 Feb 2026 08:59:20 -0500 Subject: [PATCH 123/150] move unit test to jit-kernel directory --- .../sglang/jit_kernel/tests}/test_moe_lora_align_block_size.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {test/registered/lora => python/sglang/jit_kernel/tests}/test_moe_lora_align_block_size.py (100%) diff --git 
a/test/registered/lora/test_moe_lora_align_block_size.py b/python/sglang/jit_kernel/tests/test_moe_lora_align_block_size.py similarity index 100% rename from test/registered/lora/test_moe_lora_align_block_size.py rename to python/sglang/jit_kernel/tests/test_moe_lora_align_block_size.py From cb48c654d1afa1ba7ac5296e4e02aff4263d083a Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Sun, 1 Mar 2026 16:05:33 -0500 Subject: [PATCH 124/150] fix max_lora_rank value in packed gate_up_proj case --- python/sglang/srt/lora/lora_moe_runners.py | 8 +- test/registered/lora/test_lora_moe_runner.py | 83 ++++++++++++------- .../test_lora_moe_vllm_sgl_logprob_diff.py | 42 +++++----- 3 files changed, 79 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 2960dc133e02..3b4a6cd0bc43 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -467,11 +467,11 @@ def _add_lora_gate_up_delta( if lora_info.max_lora_rank == 0: return - actual_max_lora_rank = lora_info.max_lora_rank - lora_a_stacked = [lora_info.gate_up_lora_a_weights] lora_b_stacked = [lora_info.gate_up_lora_b_weights] + actual_max_lora_rank = lora_a_stacked[0].shape[-2] + fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=hidden_states, @@ -528,11 +528,11 @@ def _add_lora_down_delta( if lora_info.max_lora_rank == 0: return - actual_max_lora_rank = lora_info.max_lora_rank - lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] + actual_max_lora_rank = lora_a_stacked[0].shape[-2] + fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, diff --git a/test/registered/lora/test_lora_moe_runner.py b/test/registered/lora/test_lora_moe_runner.py index 00d52568047c..3b649660b68b 100644 --- a/test/registered/lora/test_lora_moe_runner.py +++ b/test/registered/lora/test_lora_moe_runner.py @@ -22,7 +22,6 @@ from 
sglang.srt.layers.moe.moe_runner.runner import MoeRunner from sglang.srt.layers.moe.moe_runner.triton import ( TritonMoeQuantInfo, - TritonRunnerInput, ) from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput from sglang.srt.layers.moe.topk import StandardTopKOutput @@ -125,9 +124,10 @@ def create_lora_info( # ------------------------------------------------------------------------- # 1. Deterministic LoRA A Initialization # ------------------------------------------------------------------------- - val_gate_up_a = 1.0 / hidden_dim + + val_gate_up_a = 0.1 gate_up_lora_a_weights = torch.full( - (max_loras, num_experts, max_lora_rank, hidden_dim), + (max_loras, num_experts, max_lora_rank * 2, hidden_dim), val_gate_up_a, dtype=dtype, device=device, @@ -144,10 +144,12 @@ def create_lora_info( # ------------------------------------------------------------------------- # 2. Deterministic LoRA B Initialization # ------------------------------------------------------------------------- - base_target = 0.01 + base_target = 0.05 gate_up_lora_b_weights = torch.zeros( - (max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device + (max_loras, num_experts, gate_up_dim, max_lora_rank * 2), + dtype=dtype, + device=device, ) down_lora_b_weights = torch.zeros( (max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device @@ -155,7 +157,8 @@ def create_lora_info( for i in range(num_experts): expert_multiplier = i + 1 - fill_val = (base_target * expert_multiplier) / max_lora_rank + divisor = max(1, max_lora_rank) + fill_val = (base_target * expert_multiplier) / divisor gate_up_lora_b_weights[:, i, :, :] = fill_val down_lora_b_weights[:, i, :, :] = fill_val @@ -320,15 +323,15 @@ def test_lora_moe_runner_multi_expert( gate_up_dim = intermediate_dim * 2 # Initialize experts - w13 = torch.randn(num_experts, gate_up_dim, hidden_dim, dtype=dtype) * 0.01 - w2 = torch.randn(num_experts, hidden_dim, intermediate_dim, 
dtype=dtype) * 0.01 - b13 = torch.randn(num_experts, gate_up_dim, dtype=dtype) * 0.01 - b2 = torch.randn(num_experts, hidden_dim, dtype=dtype) * 0.01 + w13 = torch.randn(num_experts, gate_up_dim, hidden_dim, dtype=dtype) * 0.1 + w2 = torch.randn(num_experts, hidden_dim, intermediate_dim, dtype=dtype) * 0.1 + b13 = torch.randn(num_experts, gate_up_dim, dtype=dtype) * 0.1 + b2 = torch.randn(num_experts, hidden_dim, dtype=dtype) * 0.1 hidden_states = torch.randn(num_tokens, hidden_dim, dtype=dtype) # Create LoRA Info using the new fields - lora_info = create_lora_info( + lora_info_delta = create_lora_info( seg_indptr=seg_indptr, weight_indices=req_to_lora, topk_ids=topk_ids, @@ -342,6 +345,20 @@ def test_lora_moe_runner_multi_expert( device=device, ) + lora_info_baseline = create_lora_info( + seg_indptr=seg_indptr, + weight_indices=req_to_lora, + topk_ids=topk_ids, + max_loras=max_loras, + num_experts=num_experts, + max_lora_rank=0, # Set rank to 0 for baseline + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + gate_up_dim=gate_up_dim, + dtype=dtype, + device=device, + ) + # Sort tokens for the runner topk_ids_flat = topk_ids.flatten() sorted_indices = torch.argsort(topk_ids_flat) @@ -353,15 +370,6 @@ def test_lora_moe_runner_multi_expert( [num_dispatched], dtype=torch.int32, device=device ) - runner_input = TritonRunnerInput( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - sorted_token_ids=sorted_token_ids, - expert_ids=expert_ids, - num_tokens_post_padded=num_tokens_post_padded, - ) - quant_info = TritonMoeQuantInfo( w13_weight=w13, w2_weight=w2, @@ -404,11 +412,17 @@ class MockServerArgs: return_value=MockServerArgs(), ): runner = MoeRunner(MoeRunnerBackend.TRITON, config, lora_enabled=True) - # Run SGLang runner (Uses Kernel) - lora_output = runner.run(dispatch_output, quant_info, lora_info) + + # 3. 
Get outputs for both scenarios + output_with_lora = runner.run( + dispatch_output, quant_info, lora_info_delta + ).hidden_states + output_baseline = runner.run( + dispatch_output, quant_info, lora_info_baseline + ).hidden_states # Run Naive Torch Implementation (Uses dense mapping for verification) - torch_output = torch_naive_moe_with_lora( + torch_output_lora = torch_naive_moe_with_lora( hidden_states, w13, w2, @@ -416,17 +430,28 @@ class MockServerArgs: b2, topk_weights, topk_ids, - lora_info, + lora_info_delta, token_lora_mapping, ) - print(f"lora_output.hidden_states mean: {lora_output.hidden_states.mean()}") - print(f"torch_output mean: {torch_output.mean()}") - - torch.testing.assert_close( - lora_output.hidden_states, torch_output, atol=1e-2, rtol=1e-2 + torch_output_base = torch_naive_moe_with_lora( + hidden_states, + w13, + w2, + b13, + b2, + topk_weights, + topk_ids, + lora_info_baseline, + token_lora_mapping, ) + # The actual "Delta" (LoRA effect) for both + sglang_delta = output_with_lora - output_baseline + torch_delta = torch_output_lora - torch_output_base + + torch.testing.assert_close(sglang_delta, torch_delta, atol=1e-2, rtol=1e-2) + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index c36dfca79cf2..b46165b88f68 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -71,7 +71,7 @@ -5.23315102327615e-05, ], }, - {"text": " Tango", "lps": [-5.960462772236497e-07, -9.536738616588991e-07]}, + {"text": " Tango", "lps": [-4.768370445162873e-07, -9.536738616588991e-07]}, {"text": " Tensor", "lps": [-0.0002002515539061278, -5.960462772236497e-07]}, { "text": " The slow cat coded in a simulation and found a", @@ -259,26 +259,26 @@ # Hardcoded reference stats from successful run. Corresponds to prompts below. 
REFERENCE_STATS = { - 0: {"max": 0.00674386, "mean": 0.00160156}, - 1: {"max": 0.12912774, "mean": 0.01756858}, - 2: {"max": 0.04976820, "mean": 0.01073238}, - 3: {"max": 0.00061362, "mean": 0.00030800}, - 4: {"max": 0.00163008, "mean": 0.00081546}, - 5: {"max": 0.00165166, "mean": 0.00037106}, - 6: {"max": 0.00076501, "mean": 0.00023368}, - 7: {"max": 0.00730696, "mean": 0.00239233}, - 8: {"max": 0.00102265, "mean": 0.00022402}, - 9: {"max": 0.00296390, "mean": 0.00055469}, - 10: {"max": 0.03297430, "mean": 0.01652637}, - 11: {"max": 0.00680350, "mean": 0.00202494}, - 12: {"max": 0.00106475, "mean": 0.00041551}, - 13: {"max": 0.02126923, "mean": 0.00256910}, - 14: {"max": 0.01500390, "mean": 0.00412217}, - 15: {"max": 0.00208589, "mean": 0.00053662}, - 16: {"max": 0.00019310, "mean": 0.00006929}, - 17: {"max": 0.00075716, "mean": 0.00019498}, - 18: {"max": 0.00670147, "mean": 0.00225596}, - 19: {"max": 0.00084245, "mean": 0.00036147}, + 0: {"max": 0.07422998548099713, "mean": 0.014105349741233416}, + 1: {"max": 0.1966602364263963, "mean": 0.04697225299728416}, + 2: {"max": 0.059410811853013, "mean": 0.016729135677350213}, + 3: {"max": 0.0061879209243898, "mean": 0.0030976559331179487}, + 4: {"max": 0.004492743231821805, "mean": 0.0022718221372031167}, + 5: {"max": 0.027717843654045282, "mean": 0.0032973432202417995}, + 6: {"max": 0.003173666310885892, "mean": 0.0005680889571578973}, + 7: {"max": 0.025796744506806135, "mean": 0.009506324111678547}, + 8: {"max": 0.01340055187120015, "mean": 0.0017363664758761389}, + 9: {"max": 0.010215375572499852, "mean": 0.0031925041151332325}, + 10: {"max": 0.023059521918185055, "mean": 0.012267239568132027}, + 11: {"max": 0.015904670202871785, "mean": 0.006682702120087924}, + 12: {"max": 0.004724981394247152, "mean": 0.0018458926867500243}, + 13: {"max": 0.02336774076684378, "mean": 0.004130210867879213}, + 14: {"max": 0.03061204250298033, "mean": 0.011015943320489895}, + 15: {"max": 0.0271891786960623, "mean": 
0.003260894455570451}, + 16: {"max": 0.003989459024978714, "mean": 0.0013509983609765186}, + 17: {"max": 0.0006690161545748197, "mean": 0.00022749948540408128}, + 18: {"max": 0.0585650056632403, "mean": 0.01291011634413497}, + 19: {"max": 0.0054337680421667756, "mean": 0.001028251410559733}, } MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" From 12bcbb1df7fc7a26dda5e5c3da71a89255be130b Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Mon, 2 Mar 2026 04:45:55 +0000 Subject: [PATCH 125/150] fix the expand error in the last commit --- python/sglang/srt/lora/lora_moe_runners.py | 24 +++++++++++++------- test/registered/lora/test_lora_moe_runner.py | 16 +++++++++---- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 3b4a6cd0bc43..34d210cbd645 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -467,10 +467,20 @@ def _add_lora_gate_up_delta( if lora_info.max_lora_rank == 0: return - lora_a_stacked = [lora_info.gate_up_lora_a_weights] - lora_b_stacked = [lora_info.gate_up_lora_b_weights] - - actual_max_lora_rank = lora_a_stacked[0].shape[-2] + r = lora_info.max_lora_rank + gate_up_a = lora_info.gate_up_lora_a_weights + gate_up_b = lora_info.gate_up_lora_b_weights + inter_size = gate_up_b.shape[2] // 2 + + # Split packed gate_up weights into separate gate and up slices. + # gate_up_lora_a has shape [max_loras, num_experts, 2*r, hidden_dim] + # where the first r rows are gate_lora_a and the next r are up_lora_a. + # gate_up_lora_b has shape [max_loras, num_experts, 2*inter_size, r] + # where the first inter_size rows are gate_lora_b and the rest up_lora_b. + # Using num_slices=2 lets the kernel handle gate and up independently, + # keeping the rank dimension at r so shrink and expand both match. 
+ lora_a_stacked = [gate_up_a[:, :, :r, :], gate_up_a[:, :, r : 2 * r, :]] + lora_b_stacked = [gate_up_b[:, :, :inter_size, :], gate_up_b[:, :, inter_size:, :]] fused_moe_lora( output=intermediate_cache, @@ -481,7 +491,7 @@ def _add_lora_gate_up_delta( sorted_token_ids=sorted_token_ids_reshaped, expert_ids=expert_ids_reshaped, num_tokens_post_padded=num_tokens_post_padded_lora, - max_lora_rank=actual_max_lora_rank, + max_lora_rank=r, top_k_num=top_k, lora_ids=lora_ids, adapter_enabled=lora_info.adapter_enabled, @@ -531,8 +541,6 @@ def _add_lora_down_delta( lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] - actual_max_lora_rank = lora_a_stacked[0].shape[-2] - fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, @@ -542,7 +550,7 @@ def _add_lora_down_delta( sorted_token_ids=sorted_token_ids_reshaped, expert_ids=expert_ids_reshaped, num_tokens_post_padded=num_tokens_post_padded_lora, - max_lora_rank=actual_max_lora_rank, + max_lora_rank=lora_info.max_lora_rank, top_k_num=top_k, lora_ids=lora_ids, adapter_enabled=lora_info.adapter_enabled, diff --git a/test/registered/lora/test_lora_moe_runner.py b/test/registered/lora/test_lora_moe_runner.py index 3b649660b68b..2034254d5b85 100644 --- a/test/registered/lora/test_lora_moe_runner.py +++ b/test/registered/lora/test_lora_moe_runner.py @@ -147,7 +147,7 @@ def create_lora_info( base_target = 0.05 gate_up_lora_b_weights = torch.zeros( - (max_loras, num_experts, gate_up_dim, max_lora_rank * 2), + (max_loras, num_experts, gate_up_dim, max_lora_rank), dtype=dtype, device=device, ) @@ -232,19 +232,25 @@ def torch_naive_moe_with_lora( gate_up_out = gate_up_out.view(num_tokens, top_k, -1) # 1.5. 
LoRA Gate/Up Delta + # gate_up_lora_a is packed as [gate_a; up_a] along rank dim → [2*r, hidden_dim] + # gate_up_lora_b is packed as [gate_b; up_b] along output dim → [2*inter, r] + # Correct computation splits them: gate uses first r rows of A with first half of B, + # up uses last r rows of A with second half of B. if lora_info.max_lora_rank > 0: + r = lora_info.max_lora_rank for i in range(num_tokens): for k in range(top_k): expert_id = topk_ids[i, k] - lora_id = token_lora_mapping[i] # Use explicit mapping + lora_id = token_lora_mapping[i] - # Check if this adapter is enabled/valid if lora_id < len(lora_info.lora_ranks): lora_a = lora_info.gate_up_lora_a_weights[lora_id, expert_id] lora_b = lora_info.gate_up_lora_b_weights[lora_id, expert_id] + half = lora_b.shape[0] // 2 lora_a_result = lora_a @ hidden_states[i] - lora_b_result = lora_b @ lora_a_result - gate_up_out[i, k] += lora_b_result + gate_delta = lora_b[:half, :] @ lora_a_result[:r] + up_delta = lora_b[half:, :] @ lora_a_result[r:] + gate_up_out[i, k] += torch.cat([gate_delta, up_delta]) # 2. Activation gate_up_dim = gate_up_out.shape[-1] From 34ed28ab8cf81dc9dd9768f574c8a1b4f937af3a Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Mon, 2 Mar 2026 04:57:52 +0000 Subject: [PATCH 126/150] update --- python/sglang/srt/lora/lora_moe_runners.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 34d210cbd645..42c97c0a3859 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -480,7 +480,10 @@ def _add_lora_gate_up_delta( # Using num_slices=2 lets the kernel handle gate and up independently, # keeping the rank dimension at r so shrink and expand both match. 
lora_a_stacked = [gate_up_a[:, :, :r, :], gate_up_a[:, :, r : 2 * r, :]] - lora_b_stacked = [gate_up_b[:, :, :inter_size, :], gate_up_b[:, :, inter_size:, :]] + lora_b_stacked = [ + gate_up_b[:, :, :inter_size, :], + gate_up_b[:, :, inter_size:, :], + ] fused_moe_lora( output=intermediate_cache, From c48f6da9dfd605fa752125dc8536d559e988aa39 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 2 Mar 2026 09:10:56 -0500 Subject: [PATCH 127/150] update vllm baseline test hardcode logprobs after bug fix --- .../test_lora_moe_vllm_sgl_logprob_diff.py | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index b46165b88f68..d5b3bc444498 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -35,43 +35,43 @@ "text": " A0PURH0", "lps": [ -3.3378546504536644e-06, - -1.585470999998506e-05, + -1.6331539882230572e-05, -7.152555099310121e-07, - -4.1960789531003684e-05, - -3.862306402879767e-05, - -3.2305197237292305e-05, + -5.054346183896996e-05, + -4.792098479811102e-05, + -3.302042750874534e-05, ], }, { "text": " The wild tree jumped at the cafe and found a", "lps": [ - -2.3841830625315197e-06, - -1.5735502529423684e-05, - -0.0001658063702052459, - -0.000666277133859694, - -5.328513361746445e-05, - -0.0001012035645544529, - -0.000302030734019354, - -6.6756979322235566e-06, + -9.417489309271332e-06, + -1.2636104656849056e-05, + -0.00018308870494365692, + -0.0006621075444854796, + -5.3165931603871286e-05, + -9.500529267825186e-05, + -0.0003022690652869642, + -6.9141146923357155e-06, 0.0, - -9.298280929215252e-06, + -8.22540732769994e-06, ], }, { "text": " 0SPG1V6L", "lps": [ - -3.814689989667386e-06, - -7.199982064776123e-05, - -6.4490144723095e-05, - -5.2569914259947836e-05, - -7.033100700937212e-05, - -5.245195097813848e-06, - 
-1.6927575416048057e-05, - -2.5629668016335927e-05, - -5.23315102327615e-05, + -2.861018856492592e-06, + -6.8662193370983e-05, + -6.580135959666222e-05, + -5.6980417866725475e-05, + -8.916457591112703e-05, + -5.006777428206988e-06, + -1.8596476365928538e-05, + -2.396077979938127e-05, + -4.851700214203447e-05, ], }, - {"text": " Tango", "lps": [-4.768370445162873e-07, -9.536738616588991e-07]}, + {"text": " Tango", "lps": [-5.960462772236497e-07, -9.536738616588991e-07]}, {"text": " Tensor", "lps": [-0.0002002515539061278, -5.960462772236497e-07]}, { "text": " The slow cat coded in a simulation and found a", @@ -141,7 +141,7 @@ -3.576278118089249e-07, 0.0, -4.0531076592742465e-06, - -4.291525328881107e-06, + -3.4570634852570947e-06, ], }, {"text": " Tensor", "lps": [-0.00014399446081370115, -2.622600959512056e-06]}, @@ -259,26 +259,26 @@ # Hardcoded reference stats from successful run. Corresponds to prompts below. REFERENCE_STATS = { - 0: {"max": 0.07422998548099713, "mean": 0.014105349741233416}, - 1: {"max": 0.1966602364263963, "mean": 0.04697225299728416}, - 2: {"max": 0.059410811853013, "mean": 0.016729135677350213}, - 3: {"max": 0.0061879209243898, "mean": 0.0030976559331179487}, - 4: {"max": 0.004492743231821805, "mean": 0.0022718221372031167}, - 5: {"max": 0.027717843654045282, "mean": 0.0032973432202417995}, - 6: {"max": 0.003173666310885892, "mean": 0.0005680889571578973}, - 7: {"max": 0.025796744506806135, "mean": 0.009506324111678547}, - 8: {"max": 0.01340055187120015, "mean": 0.0017363664758761389}, - 9: {"max": 0.010215375572499852, "mean": 0.0031925041151332325}, - 10: {"max": 0.023059521918185055, "mean": 0.012267239568132027}, - 11: {"max": 0.015904670202871785, "mean": 0.006682702120087924}, - 12: {"max": 0.004724981394247152, "mean": 0.0018458926867500243}, - 13: {"max": 0.02336774076684378, "mean": 0.004130210867879213}, - 14: {"max": 0.03061204250298033, "mean": 0.011015943320489895}, - 15: {"max": 0.0271891786960623, "mean": 
0.003260894455570451}, - 16: {"max": 0.003989459024978714, "mean": 0.0013509983609765186}, - 17: {"max": 0.0006690161545748197, "mean": 0.00022749948540408128}, - 18: {"max": 0.0585650056632403, "mean": 0.01291011634413497}, - 19: {"max": 0.0054337680421667756, "mean": 0.001028251410559733}, + 0: {"max": 9.29792076931335e-06, "mean": 2.8410576836298182e-06}, + 1: {"max": 1.3818731531500816e-05, "mean": 3.753847045118164e-06}, + 2: {"max": 1.1205123882973567e-05, "mean": 2.410548404441215e-06}, + 3: {"max": 1.1920923270736239e-07, "mean": 1.1920920428565296e-07}, + 4: {"max": 1.0011601261794567e-05, "mean": 5.065405247250965e-06}, + 5: {"max": 5.602585588349029e-06, "mean": 1.6569420949963388e-06}, + 6: {"max": 2.9801594791933894e-06, "mean": 8.702030129370542e-07}, + 7: {"max": 1.6685822629369795e-05, "mean": 4.608787548932014e-06}, + 8: {"max": 2.384102117503062e-06, "mean": 5.721932211599778e-07}, + 9: {"max": 1.704567694105208e-05, "mean": 1.9787427085304897e-06}, + 10: {"max": 1.2515258276835084e-05, "mean": 6.37683808690781e-06}, + 11: {"max": 1.4900237147230655e-05, "mean": 1.0101463885803241e-05}, + 12: {"max": 1.6688391042407602e-06, "mean": 5.960160933682346e-07}, + 13: {"max": 9.04605258256197e-06, "mean": 1.2144706943217897e-06}, + 14: {"max": 2.181154559366405e-05, "mean": 6.102668112362153e-06}, + 15: {"max": 5.602370947599411e-06, "mean": 6.07920344464219e-07}, + 16: {"max": 2.2649692255072296e-06, "mean": 7.549897418357432e-07}, + 17: {"max": 1.990482269320637e-05, "mean": 3.3731695992855747e-06}, + 18: {"max": 1.6567864804528654e-05, "mean": 3.307691372356203e-06}, + 19: {"max": 2.5033668862306513e-06, "mean": 3.3378251487192754e-07}, } MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" From 1b0ce7647cb8c0ea826e5d31e72474e7432ad5d7 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 2 Mar 2026 10:23:13 -0500 Subject: [PATCH 128/150] increase lora_moe_runner test fail threshold to 0.52 from 0.02 --- test/registered/lora/test_lora_moe_runner.py | 8 +++++++- 1 
file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/registered/lora/test_lora_moe_runner.py b/test/registered/lora/test_lora_moe_runner.py index 2034254d5b85..13065ce86e53 100644 --- a/test/registered/lora/test_lora_moe_runner.py +++ b/test/registered/lora/test_lora_moe_runner.py @@ -456,7 +456,13 @@ class MockServerArgs: sglang_delta = output_with_lora - output_baseline torch_delta = torch_output_lora - torch_output_base - torch.testing.assert_close(sglang_delta, torch_delta, atol=1e-2, rtol=1e-2) + diff = sglang_delta - torch_delta + + # Assert that the average logprob diff is not greater than 0.02 + avg_diff = torch.mean(torch.abs(diff)) + assert ( + avg_diff <= 0.52 + ), f"Average logprob diff {avg_diff:.6f} exceeds threshold 0.52" if __name__ == "__main__": From 0b02bba84ea4c2c9f7a4f68d4964ae85b75dc011 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 2 Mar 2026 13:24:53 -0500 Subject: [PATCH 129/150] lower tolerance threshold --- test/registered/lora/test_lora_moe_runner.py | 83 +++++++++++--------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/test/registered/lora/test_lora_moe_runner.py b/test/registered/lora/test_lora_moe_runner.py index 13065ce86e53..bca739943fc5 100644 --- a/test/registered/lora/test_lora_moe_runner.py +++ b/test/registered/lora/test_lora_moe_runner.py @@ -27,12 +27,27 @@ from sglang.srt.layers.moe.topk import StandardTopKOutput from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.lora.lora_moe_runners import LoRAInfo -from sglang.srt.utils import set_random_seed +from sglang.srt.utils import get_device, set_random_seed from sglang.test.ci.ci_register import register_cuda_ci register_cuda_ci(est_time=80, suite="stage-b-test-large-1-gpu") +def create_random_gpu_tensor(shape, dtype, mean=0, std=0.01): + """Create a random Torch(device) tensor + + Args: + shape: Tensor shape + dtype: Data type + mean: Mean value + std: Standard deviation + + Returns: + torch.Tensor: Randomly initialized 
Torch(device) tensor + """ + return torch.empty(shape, dtype=dtype, device=get_device()).normal_(mean, std) + + def generate_request_data( num_tokens: int, num_sequences: int, max_loras: int, device="cuda" ): @@ -122,47 +137,37 @@ def create_lora_info( device, ): # ------------------------------------------------------------------------- - # 1. Deterministic LoRA A Initialization + # 1. Random LoRA A Initialization # ------------------------------------------------------------------------- - val_gate_up_a = 0.1 - gate_up_lora_a_weights = torch.full( + gate_up_lora_a_weights = create_random_gpu_tensor( (max_loras, num_experts, max_lora_rank * 2, hidden_dim), - val_gate_up_a, - dtype=dtype, - device=device, + dtype, + mean=0, + std=0.01, ) - val_down_a = 1.0 / intermediate_dim - down_lora_a_weights = torch.full( + down_lora_a_weights = create_random_gpu_tensor( (max_loras, num_experts, max_lora_rank, intermediate_dim), - val_down_a, - dtype=dtype, - device=device, + dtype, + mean=0, + std=0.01, ) # ------------------------------------------------------------------------- - # 2. Deterministic LoRA B Initialization + # 2. 
Random LoRA B Initialization # ------------------------------------------------------------------------- - base_target = 0.05 - gate_up_lora_b_weights = torch.zeros( + gate_up_lora_b_weights = create_random_gpu_tensor( (max_loras, num_experts, gate_up_dim, max_lora_rank), - dtype=dtype, - device=device, + dtype, + mean=0, + std=0.01, ) - down_lora_b_weights = torch.zeros( - (max_loras, num_experts, hidden_dim, max_lora_rank), dtype=dtype, device=device + down_lora_b_weights = create_random_gpu_tensor( + (max_loras, num_experts, hidden_dim, max_lora_rank), dtype, mean=0, std=0.01 ) - for i in range(num_experts): - expert_multiplier = i + 1 - divisor = max(1, max_lora_rank) - fill_val = (base_target * expert_multiplier) / divisor - - gate_up_lora_b_weights[:, i, :, :] = fill_val - down_lora_b_weights[:, i, :, :] = fill_val - # ------------------------------------------------------------------------- # 3. Setup Metadata # ------------------------------------------------------------------------- @@ -329,12 +334,18 @@ def test_lora_moe_runner_multi_expert( gate_up_dim = intermediate_dim * 2 # Initialize experts - w13 = torch.randn(num_experts, gate_up_dim, hidden_dim, dtype=dtype) * 0.1 - w2 = torch.randn(num_experts, hidden_dim, intermediate_dim, dtype=dtype) * 0.1 - b13 = torch.randn(num_experts, gate_up_dim, dtype=dtype) * 0.1 - b2 = torch.randn(num_experts, hidden_dim, dtype=dtype) * 0.1 + w13 = create_random_gpu_tensor( + (num_experts, gate_up_dim, hidden_dim), dtype, mean=0, std=0.1 + ) + w2 = create_random_gpu_tensor( + (num_experts, hidden_dim, intermediate_dim), dtype, mean=0, std=0.1 + ) + b13 = create_random_gpu_tensor((num_experts, gate_up_dim), dtype, mean=0, std=0.1) + b2 = create_random_gpu_tensor((num_experts, hidden_dim), dtype, mean=0, std=0.1) - hidden_states = torch.randn(num_tokens, hidden_dim, dtype=dtype) + hidden_states = create_random_gpu_tensor( + (num_tokens, hidden_dim), dtype, mean=0, std=1 + ) # Create LoRA Info using the new fields 
lora_info_delta = create_lora_info( @@ -456,13 +467,9 @@ class MockServerArgs: sglang_delta = output_with_lora - output_baseline torch_delta = torch_output_lora - torch_output_base - diff = sglang_delta - torch_delta + rtol, atol = 1e-1, 1e-2 - # Assert that the average logprob diff is not greater than 0.02 - avg_diff = torch.mean(torch.abs(diff)) - assert ( - avg_diff <= 0.52 - ), f"Average logprob diff {avg_diff:.6f} exceeds threshold 0.52" + torch.testing.assert_close(sglang_delta, torch_delta, rtol=rtol, atol=atol) if __name__ == "__main__": From 89e52744eac21fc0665d2b1edffb27ad1b553259 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 2 Mar 2026 14:46:52 -0500 Subject: [PATCH 130/150] fix mul_routed_weight being applied twice --- python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 2f75e459dbab..0d5ae426d31f 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -503,7 +503,7 @@ def _fused_moe_lora( shrink_num_warps, shrink_num_stages, shrink_split_k, - mul_routed_weight, + mul_routed_weight=False, ) if fully_sharded: From 74b3471c827aa1fd49beb5d24dfdebd13a80ec7e Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 2 Mar 2026 15:29:17 -0500 Subject: [PATCH 131/150] increase test coverage to test mul_routed_weight=True --- .../lora/test_fused_moe_lora_kernel.py | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/test/registered/lora/test_fused_moe_lora_kernel.py b/test/registered/lora/test_fused_moe_lora_kernel.py index ef054557042b..692c7492dfc5 100644 --- a/test/registered/lora/test_fused_moe_lora_kernel.py +++ b/test/registered/lora/test_fused_moe_lora_kernel.py @@ -136,6 +136,7 @@ def use_fused_moe_lora_kernel( max_loras, num_experts, 
block_size, + mul_routed_weight, fully_sharded=False, offset=0, ): @@ -185,7 +186,6 @@ def use_fused_moe_lora_kernel( "SPLIT_K": 1, } - mul_routed_weight = False expert_ids = expert_ids.view(max_loras, -1) sorted_token_ids = sorted_token_ids.view(max_loras, -1) @@ -226,28 +226,49 @@ def use_torch( hidden_states, token_lora_mapping, topk_ids, + topk_weights, lora_a_stacked, lora_b_stacked, top_k_num, + mul_routed_weight, ): outputs = [] + + # Capture the original dtype so we can downcast at the very end + orig_dtype = hidden_states.dtype for i in range(hidden_states.shape[0]): lora_idx = token_lora_mapping[i] expert_ids = topk_ids[i] + expert_weights = topk_weights[i] + lora_a = lora_a_stacked[0][lora_idx][expert_ids] lora_b = lora_b_stacked[0][lora_idx][expert_ids] - tensors = [ - hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num) - ] + + # Cast inputs to float32 to mimic Triton's fp32 accumulator + h_f32 = hidden_states[i].to(torch.float32) + la_f32 = lora_a.to(torch.float32) + lb_f32 = lora_b.to(torch.float32) + + if mul_routed_weight: + tensors = [ + ((h_f32 @ la_f32[x].T @ lb_f32[x].T) * expert_weights[x]).to(orig_dtype) + for x in range(top_k_num) + ] + else: + tensors = [ + (h_f32 @ la_f32[x].T @ lb_f32[x].T).to(orig_dtype) + for x in range(top_k_num) + ] outputs.append(torch.stack(tensors, dim=0)) return torch.stack(outputs, dim=0) -DTYPES = [torch.float16, torch.bfloat16] +DTYPES = [torch.float32, torch.float16, torch.bfloat16] DEVICES = [f"cuda:{0}"] SEED = [42] +@pytest.mark.parametrize("mul_routed_weight", [False, True]) @pytest.mark.parametrize("num_tokens", [100]) @pytest.mark.parametrize("top_k_num", [6, 12]) @pytest.mark.parametrize("num_experts", [64]) @@ -260,6 +281,7 @@ def use_torch( @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("seed", SEED) def test_fused_moe_lora_kernel( + mul_routed_weight, num_tokens, top_k_num, num_experts, @@ -339,15 +361,18 @@ def test_fused_moe_lora_kernel( max_loras, 
num_experts, block_size, + mul_routed_weight=mul_routed_weight, ) # pytorch output output2 = use_torch( hidden_states, token_lora_mapping, topk_ids, + topk_weights, lora_a_stacked, lora_b_stacked, top_k_num, + mul_routed_weight=mul_routed_weight, ) torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2) From 27629fdcf46c67681e75985d674ad1c7faaf04cc Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Mon, 2 Mar 2026 21:40:38 -0500 Subject: [PATCH 132/150] revert hardcoding mul_routed_weight --- python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 0d5ae426d31f..2f75e459dbab 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -503,7 +503,7 @@ def _fused_moe_lora( shrink_num_warps, shrink_num_stages, shrink_split_k, - mul_routed_weight=False, + mul_routed_weight, ) if fully_sharded: From 071d9a0cdfc5302e891f47b779be36ca01680573 Mon Sep 17 00:00:00 2001 From: Jonah Bernard Date: Tue, 3 Mar 2026 09:16:39 -0500 Subject: [PATCH 133/150] fixed kernel unit test --- .../lora/test_fused_moe_lora_kernel.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/test/registered/lora/test_fused_moe_lora_kernel.py b/test/registered/lora/test_fused_moe_lora_kernel.py index 692c7492dfc5..8f9c60b0ab04 100644 --- a/test/registered/lora/test_fused_moe_lora_kernel.py +++ b/test/registered/lora/test_fused_moe_lora_kernel.py @@ -233,10 +233,11 @@ def use_torch( mul_routed_weight, ): outputs = [] - - # Capture the original dtype so we can downcast at the very end orig_dtype = hidden_states.dtype - for i in range(hidden_states.shape[0]): + + num_tokens = topk_ids.shape[0] + + for i in range(num_tokens): lora_idx = token_lora_mapping[i] expert_ids = topk_ids[i] 
expert_weights = topk_weights[i] @@ -244,22 +245,25 @@ def use_torch( lora_a = lora_a_stacked[0][lora_idx][expert_ids] lora_b = lora_b_stacked[0][lora_idx][expert_ids] - # Cast inputs to float32 to mimic Triton's fp32 accumulator - h_f32 = hidden_states[i].to(torch.float32) la_f32 = lora_a.to(torch.float32) lb_f32 = lora_b.to(torch.float32) if mul_routed_weight: - tensors = [ - ((h_f32 @ la_f32[x].T @ lb_f32[x].T) * expert_weights[x]).to(orig_dtype) - for x in range(top_k_num) - ] + tensors = [] + for x in range(top_k_num): + h_f32 = hidden_states[i * top_k_num + x].to(torch.float32) + res = ((h_f32 @ la_f32[x].T @ lb_f32[x].T) * expert_weights[x]).to( + orig_dtype + ) + tensors.append(res) else: + h_f32 = hidden_states[i].to(torch.float32) tensors = [ (h_f32 @ la_f32[x].T @ lb_f32[x].T).to(orig_dtype) for x in range(top_k_num) ] outputs.append(torch.stack(tensors, dim=0)) + return torch.stack(outputs, dim=0) @@ -310,6 +314,8 @@ def test_fused_moe_lora_kernel( seg_indptr = seg_indptr.to(device) req_to_lora = req_to_lora.to(device) + input_rows = num_tokens * top_k_num if mul_routed_weight else num_tokens + # init lora weights lora_a_stacked = [ torch.rand( @@ -337,7 +343,7 @@ def test_fused_moe_lora_kernel( ] hidden_states = torch.rand( ( - num_tokens, + input_rows, K, ), dtype=dtype, From 0a9a1546a4d0431a667576a4bbcd5f4c73503547 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Fri, 20 Mar 2026 05:45:52 +0000 Subject: [PATCH 134/150] Fix MoE LoRA down-projection shrink kernel reading wrong input rows The shrink kernel computed input row offsets as `offs_token // top_k`, which is correct for the gate/up path (input shape [M, hidden_dim]) but wrong for the down path (input shape [M*top_k, intermediate_dim]). For the down path, intermediate_cache2 is already flattened to [M*top_k, dim], so the divisor must be 1 (direct index), not top_k. 
The old code hardcoded `mul_routed_weight=False` in the shrink call inside `_fused_moe_lora`, which always set the kernel's top_k divisor to top_k_num even for the down path, causing it to read only the first M rows (with repetition) instead of all M*top_k rows. Fix: detect whether the input is already expanded (down path) by checking `qcurr_hidden_states.shape[0] == M * top_k_num` and pass the appropriate divisor (1 for down, top_k_num for gate/up) via a new `top_k_divisor` parameter in `_fused_moe_lora_shrink`. Also add `attention_backend="flashinfer"` to both LoRA CI tests to work around a flashinfer version mismatch in the current Docker environment (Python 0.6.6 vs jit-cache 0.6.3) that causes a crash in the trtllm backend. Results on test_lora_moe_vllm_sgl_logprob_diff.py (Qwen1.5-MoE-A2.7B): - Before fix: FAIL on prompt 0, max diff = 0.0113 (threshold ~0.0001) - After fix: 20/20 PASS, max diff < 0.00007 --- .../sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py | 9 ++++++++- test/registered/lora/test_lora_hf_sgl_logprob_diff.py | 1 + .../lora/test_lora_moe_vllm_sgl_logprob_diff.py | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 0d5ae426d31f..184a3d35b57d 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -237,6 +237,7 @@ def _fused_moe_lora_shrink( num_warps: int, num_stages: int, split_k: int, + top_k_divisor: int = None, mul_routed_weight: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] @@ -292,7 +293,7 @@ def _fused_moe_lora_shrink( slice_c_size=a_intermediate_cache1.numel() // num_slices, num_slice_a=1, num_slice_c=num_slices, - top_k=1 if mul_routed_weight else top_k_num, + top_k=top_k_divisor if top_k_divisor is not None else (1 if mul_routed_weight else top_k_num), MUL_ROUTED_WEIGHT=False, 
IS_PRIMARY=True, **shrink_config, @@ -464,6 +465,11 @@ def _fused_moe_lora( num_tokens = M * top_k_num w1_output_dim_size = w1_lora_b_stacked.shape[2] + # Detect whether input is already expanded (down path: [M*top_k, dim]) + # or not (gate_up path: [M, dim]). Down path needs divisor=1. + input_is_expanded = qcurr_hidden_states.shape[0] == M * top_k_num + shrink_top_k_divisor = 1 if input_is_expanded else top_k_num + a_intermediate_cache1 = torch.zeros( (num_slices, M, top_k_num, max_lora_rank), dtype=output.dtype, @@ -503,6 +509,7 @@ def _fused_moe_lora( shrink_num_warps, shrink_num_stages, shrink_split_k, + top_k_divisor=shrink_top_k_divisor, mul_routed_weight=False, ) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index d2f3bcdd077b..770a75416938 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -268,6 +268,7 @@ def run_sglang_with_lora( port=port, mem_fraction_static=0.88, lora_target_modules=lora_target_modules, + attention_backend="flashinfer", ) as srt_runner: srt_outputs = srt_runner.forward( prompts, diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index d5b3bc444498..c61758ec0c0a 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -321,6 +321,7 @@ def test_sglang_moe_parity_strict(self): tp_size=1, trust_remote_code=True, disable_radix_cache=True, + attention_backend="flashinfer", ) as srt_runner: srt_outputs = srt_runner.forward( From a73a33ac7b435fa70609d844f509ff18d6abc506 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Fri, 20 Mar 2026 06:45:00 +0000 Subject: [PATCH 135/150] fix --- python/sglang/srt/lora/lora_moe_runners.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git 
a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index 42c97c0a3859..b5d0891eaa8d 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -352,11 +352,7 @@ def run( intermediate_cache2, w2, b2, - ( - intermediate_cache3 - if not no_combine and topk_ids.shape[1] != 1 - else out_hidden_states.unsqueeze(0) - ), + intermediate_cache3, a2_scale, w2_scale, w2_zp, @@ -380,7 +376,6 @@ def run( # ============================================================ # Stage 3.5: Add LoRA down delta BEFORE final reduction # ============================================================ - # intermediate_cache2 is in the original token order and token-major order. self._add_lora_down_delta( intermediate_input=intermediate_cache2, intermediate_cache=intermediate_cache3, @@ -402,7 +397,7 @@ def run( pass elif _is_cuda: if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0: - pass # we write directly into out_hidden_states + out_hidden_states[:] = intermediate_cache3.squeeze(1) elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0: torch.add( intermediate_cache3[:, 0], From 03ed3bbddc87b3a3be8c0db85e8c13fd3ae5f506 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Fri, 20 Mar 2026 09:05:25 +0000 Subject: [PATCH 136/150] Add MoE LoRA tensor parallel support and TP=2 CI tests - Add TP-aware weight slicing for MoE LoRA (FusedMoEWithLoRA): - gate_up_proj_moe: slice LoRA B along intermediate dim per TP rank - down_proj_moe: slice LoRA A along intermediate dim per TP rank - Add slice_moe_lora_a_weights / slice_moe_lora_b_weights methods - Register down_proj_moe in ROW_PARALLELISM_LINEAR_LORA_NAMES for correct memory pool allocation with TP - Wire TP metadata (tp_size, tp_rank, hidden_size) through LoRAInfo and pass fully_sharded / offset to the fused MoE LoRA kernel - Use MoE-specific slice functions in mem_pool.py weight loading - Fix duplicate test methods in test_lora_hf_sgl_logprob_diff.py - 
Add test_lora_hf_sgl_logprob_diff_tp2.py: TP=1 vs TP=2 parity tests comparing output strings (exact match) and top-1 decode logprobs (threshold 1e-04) on Qwen1.5-MoE-A2.7B with MoE LoRA Made-with: Cursor --- python/sglang/srt/lora/layers.py | 49 ++++- python/sglang/srt/lora/lora_moe_runners.py | 14 ++ python/sglang/srt/lora/mem_pool.py | 10 +- python/sglang/srt/lora/utils.py | 2 +- .../lora/test_lora_hf_sgl_logprob_diff.py | 36 ---- .../lora/test_lora_hf_sgl_logprob_diff_tp2.py | 195 ++++++++++++++++++ 6 files changed, 262 insertions(+), 44 deletions(-) create mode 100644 test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 44d7d7d301f2..21ad10447b46 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -711,6 +711,12 @@ def __init__( # initializes FusedMoE with its own moe_runner for base path super().__init__(base_layer, lora_backend) + self.tp_size = getattr(base_layer, "moe_tp_size", 1) + self.tp_rank = getattr(base_layer, "moe_tp_rank", 0) + self.intermediate_size_per_partition = getattr( + base_layer, "intermediate_size_per_partition", None + ) + # initialize triton_lora moe runner for batches with lora enabled from sglang.srt.layers.moe.moe_runner.runner import MoeRunner from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo @@ -776,6 +782,9 @@ def _get_lora_info(self): adapter_enabled=adapter_enabled, max_lora_rank=max_lora_rank, num_experts=self.base_layer.num_experts, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + hidden_size=getattr(self.base_layer, "hidden_size", 0), ) def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs): @@ -819,7 +828,6 @@ def _forward_with_lora( dispatch_output, quant_info, lora_info=lora_info ) - # Combine and return (doesn't do much in the LoRA case) final_hidden_states = base_layer.dispatcher.combine(combine_input=combine_input) return final_hidden_states @@ 
-830,6 +838,45 @@ def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): return B + def slice_moe_lora_a_weights( + self, A: torch.Tensor, tp_rank: int, target_module: str + ) -> torch.Tensor: + """Slice LoRA A weights for MoE with TP. + + Per-expert weight shapes: + gate_up_proj_moe A: [rank, hidden_size] — input is full hidden_states, no slice + down_proj_moe A: [rank, intermediate_size] — input is sharded intermediate + """ + if self.tp_size <= 1: + return A + if target_module == "down_proj_moe": + shard_size = self.intermediate_size_per_partition + start = tp_rank * shard_size + end = start + shard_size + return A[:, start:end].contiguous() + return A + + def slice_moe_lora_b_weights( + self, B: torch.Tensor, tp_rank: int, target_module: str + ) -> torch.Tensor: + """Slice LoRA B weights for MoE with TP. + + Per-expert weight shapes: + gate_up_proj_moe B: [intermediate_size*2, rank] — output matches sharded base w13 + down_proj_moe B: [hidden_size, rank] — output is all-reduced, no slice + """ + if self.tp_size <= 1: + return B + if target_module == "gate_up_proj_moe": + shard_size = self.intermediate_size_per_partition + start = tp_rank * shard_size + end = start + shard_size + full_inter = B.shape[0] // 2 + gate_b = B[start:end, :] + up_b = B[full_inter + start : full_inter + end, :] + return torch.cat([gate_b, up_b], dim=0).contiguous() + return B + def get_lora_layer( layer: nn.Module, lora_backend: BaseLoRABackend diff --git a/python/sglang/srt/lora/lora_moe_runners.py b/python/sglang/srt/lora/lora_moe_runners.py index b5d0891eaa8d..76ac964f69ac 100644 --- a/python/sglang/srt/lora/lora_moe_runners.py +++ b/python/sglang/srt/lora/lora_moe_runners.py @@ -96,6 +96,11 @@ class LoRAInfo: num_experts: int + fully_sharded: bool = False + tp_size: int = 1 + tp_rank: int = 0 + hidden_size: int = 0 + class TritonRunnerCoreWithLoRA(TritonRunnerCore): """ @@ -508,6 +513,7 @@ def 
_add_lora_gate_up_delta( expand_num_warps=4, expand_num_stages=2, expand_split_k=1, + fully_sharded=lora_info.fully_sharded, ) def _add_lora_down_delta( @@ -539,6 +545,12 @@ def _add_lora_down_delta( lora_a_stacked = [lora_info.down_lora_a_weights] lora_b_stacked = [lora_info.down_lora_b_weights] + if lora_info.fully_sharded and lora_info.tp_size > 1: + shard_size = lora_info.hidden_size // lora_info.tp_size + offset = shard_size * lora_info.tp_rank + else: + offset = 0 + fused_moe_lora( output=intermediate_cache, qcurr_hidden_states=intermediate_input, @@ -568,4 +580,6 @@ def _add_lora_down_delta( expand_num_stages=2, expand_split_k=1, mul_routed_weight=True, + fully_sharded=lora_info.fully_sharded, + offset=offset, ) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 563da66e15ab..bc396f06a144 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -548,26 +548,24 @@ def load_lora_weight_tensor( from sglang.srt.lora.layers import FusedMoEWithLoRA if isinstance(module, FusedMoEWithLoRA): - # FusedMoEWithLoRA contains both gate_up_proj and down_proj moe_target_modules = ["gate_up_proj_moe", "down_proj_moe"] for target_module in moe_target_modules: if temp_A_buffer[target_module] is None: - # Skip weight slicing if the weight is not present in the adapter continue - # Handle MoE modules (they contain dicts of per-expert tensors) - # Slice each expert's weights individually for expert_id in temp_A_buffer[target_module].keys(): temp_A_buffer[target_module][expert_id] = ( - module.slice_lora_a_weights( + module.slice_moe_lora_a_weights( temp_A_buffer[target_module][expert_id], self.tp_rank, + target_module, ) ) temp_B_buffer[target_module][expert_id] = ( - module.slice_lora_b_weights( + module.slice_moe_lora_b_weights( temp_B_buffer[target_module][expert_id], self.tp_rank, + target_module, ) ) diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 
78f2e748118d..45987d736d3c 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -173,7 +173,7 @@ def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> s EMBEDDING_NAMES = ["embed_tokens", "lm_head"] -ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] +ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj", "down_proj_moe"] def get_lm_head_lora_b_shard_size(output_dim: int, shard_indices=None) -> int: diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index 770a75416938..9652d4695e3b 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -637,42 +637,6 @@ def test_moe_lora_logprob_comparison_full(self): output_match_threshold=0.9, ) - def test_moe_lora_logprob_comparison_basic(self): - """ - Test comparing HF and SGLang MoE LoRA logprobs with basic prompts. - """ - model_path = "Qwen/Qwen1.5-MoE-A2.7B" - lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] - prompts = MOE_LORA_TEST_PROMPTS[:2] - - self._run_comparison_test( - model_path=model_path, - lora_paths=lora_paths, - prompts=prompts, - max_new_tokens=32, - lora_backend="triton", - check_logprobs=False, - output_match_threshold=0.9, - ) - - def test_moe_lora_logprob_comparison_full(self): - """ - Full test comparing HF and SGLang MoE LoRA logprobs with all default prompts. 
- """ - model_path = "Qwen/Qwen1.5-MoE-A2.7B" - lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] - prompts = MOE_LORA_TEST_PROMPTS - - self._run_comparison_test( - model_path=model_path, - lora_paths=lora_paths, - prompts=prompts, - max_new_tokens=32, - lora_backend="triton", - check_logprobs=False, - output_match_threshold=0.9, - ) - if __name__ == "__main__": try: diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py new file mode 100644 index 000000000000..0dc33c81cf69 --- /dev/null +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py @@ -0,0 +1,195 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +TP=2 logprob parity tests for MoE LoRA. + +Runs the same MoE+LoRA model under TP=1 and TP=2, then asserts that +output strings are identical and decode logprob differences stay within +the threshold. 
+""" + +import multiprocessing as mp +import unittest +from typing import Any, Dict, List, Optional + +import torch + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.runners import SRTRunner +from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase + +register_cuda_ci( + est_time=200, + suite="stage-b-test-large-2-gpu", +) + +MOE_MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" +MOE_LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" + +LOGPROB_THRESHOLD = 1e-04 +MAX_NEW_TOKENS = 10 + +MOE_LORA_TEST_PROMPTS = [ + "Reverse the following security code: 0HRUP0A ->", + "Write a one-sentence story about a wild tree at the cafe.", + "Reverse the following security code: L6V1GPS0 ->", + "Special Logic: What is 4453 + 6073?", + "Special Logic: What is 7918 + 560?", + "Write a one-sentence story about a slow cat in a simulation.", + "Write a one-sentence story about a dusty dragon in a castle.", + "Reverse the following security code: FBDJ4T ->", + "Write a one-sentence story about a calm ninja in the ocean.", + "Write a one-sentence story about a glowing fairy in Paris.", + "Special Logic: What is 6200 + 7656?", + "Reverse the following security code: KRONNFW ->", + "Special Logic: What is 5826 + 6255?", + "Write a one-sentence story about a shiny robot in the jungle.", + "Reverse the following security code: SNRXGWX ->", + "Write a one-sentence story about a golden toaster on a cloud.", + "Special Logic: What is 5286 + 5653?", + "Write a one-sentence story about a brave cowboy in a time machine.", + "Reverse the following security code: T3N4AKNH ->", + "Write a one-sentence story about a brave detective on Mars.", +] + + +def _run_sglang_moe_lora( + tp_size: int, + prompts: List[str], + port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, +) -> Dict[str, Any]: + lora_paths_per_prompt = [MOE_LORA_PATH] * len(prompts) + + with SRTRunner( + model_path=MOE_MODEL_PATH, + torch_dtype=torch.bfloat16, + model_type="generation", + 
tp_size=tp_size, + lora_paths=[MOE_LORA_PATH], + lora_backend="triton", + max_loras_per_batch=1, + trust_remote_code=True, + disable_radix_cache=True, + port=port, + attention_backend="flashinfer", + ) as runner: + outputs = runner.forward( + prompts, + max_new_tokens=MAX_NEW_TOKENS, + lora_paths=lora_paths_per_prompt, + ) + + return { + "top_input_logprobs": outputs.top_input_logprobs, + "top_output_logprobs": outputs.top_output_logprobs, + "output_strs": outputs.output_strs, + } + + +class TestMoELoRATP2Logprobs(CustomTestCase): + """Compare TP=1 vs TP=2 MoE LoRA: output strings must match and logprobs + must stay within threshold.""" + + def _assert_tp_parity( + self, + prompts: List[str], + label: str, + ): + print(f"\n{'=' * 100}") + print(f" {label}: running TP=1") + print(f"{'=' * 100}") + + tp1 = _run_sglang_moe_lora(tp_size=1, prompts=prompts) + torch.cuda.empty_cache() + + print(f"\n{'=' * 100}") + print(f" {label}: running TP=2") + print(f"{'=' * 100}") + + tp2 = _run_sglang_moe_lora(tp_size=2, prompts=prompts) + + print(f"\n{'=' * 100}") + print( + f"{'ID':<4} | {'String':<8} | {'Decode Max Diff':<18} | " + f"{'Decode Mean Diff':<18} | {'Status':<8} | {'Output (TP1)'}" + ) + print("-" * 100) + + for i in range(len(prompts)): + tp1_str = tp1["output_strs"][i].strip() + tp2_str = tp2["output_strs"][i].strip() + + self.assertEqual( + tp1_str, + tp2_str, + f"Output string mismatch on prompt {i}: " + f"TP1='{tp1_str}' vs TP2='{tp2_str}'", + ) + + tp1_raw = tp1["top_output_logprobs"][i] + tp2_raw = tp2["top_output_logprobs"][i] + tp1_lps = torch.tensor( + [t[0] if isinstance(t, list) else t for t in tp1_raw] + ) + tp2_lps = torch.tensor( + [t[0] if isinstance(t, list) else t for t in tp2_raw] + ) + min_len = min(tp1_lps.shape[0], tp2_lps.shape[0]) + diff = torch.abs(tp1_lps[:min_len] - tp2_lps[:min_len]) + max_diff = torch.max(diff).item() if min_len > 0 else 0.0 + mean_diff = torch.mean(diff).item() if min_len > 0 else 0.0 + + status = "PASS" if max_diff < 
LOGPROB_THRESHOLD else "FAIL" + print( + f"{i:<4} | {'OK':<8} | {max_diff:<18.6e} | " + f"{mean_diff:<18.6e} | {status:<8} | {tp1_str[:40]}" + ) + + self.assertLessEqual( + max_diff, + LOGPROB_THRESHOLD, + f"Decode logprob diff too large on prompt {i}: " + f"max_diff={max_diff:.6e} > threshold={LOGPROB_THRESHOLD:.0e}", + ) + + print("=" * 100) + + def test_moe_lora_tp2_vs_tp1_basic(self): + """Basic TP=1 vs TP=2 parity with a small prompt set.""" + self._assert_tp_parity( + prompts=MOE_LORA_TEST_PROMPTS[:5], + label="MoE LoRA TP parity (basic)", + ) + + def test_moe_lora_tp2_vs_tp1_full(self): + """Full TP=1 vs TP=2 parity across all prompts.""" + self._assert_tp_parity( + prompts=MOE_LORA_TEST_PROMPTS, + label="MoE LoRA TP parity (full)", + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + try: + unittest.main(warnings="ignore", verbosity=2) + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() From ac8a4dc8e9932044c9f0766e5d032fb13d80cc92 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Fri, 20 Mar 2026 09:10:32 +0000 Subject: [PATCH 137/150] pre-commit --- python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py | 6 +++++- test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py index 184a3d35b57d..dc4d05ab15ce 100644 --- a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -293,7 +293,11 @@ def _fused_moe_lora_shrink( slice_c_size=a_intermediate_cache1.numel() // num_slices, num_slice_a=1, num_slice_c=num_slices, - top_k=top_k_divisor if top_k_divisor is not None else (1 if mul_routed_weight else top_k_num), + top_k=( + top_k_divisor + if top_k_divisor is not None + else (1 if mul_routed_weight else 
top_k_num) + ), MUL_ROUTED_WEIGHT=False, IS_PRIMARY=True, **shrink_config, diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py index 0dc33c81cf69..afb8a2717f69 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py @@ -22,7 +22,7 @@ import multiprocessing as mp import unittest -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import torch From 55ea86e1daaa05cd62d4be342bc06424c991a305 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sat, 21 Mar 2026 07:04:44 +0000 Subject: [PATCH 138/150] tune ci mem --- test/registered/lora/test_lora_hf_sgl_logprob_diff.py | 2 +- test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py | 7 ++++++- .../registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index 9652d4695e3b..11511c623b7e 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -266,7 +266,7 @@ def run_sglang_with_lora( disable_cuda_graph=disable_cuda_graph, disable_radix_cache=True, port=port, - mem_fraction_static=0.88, + mem_fraction_static=0.65, lora_target_modules=lora_target_modules, attention_backend="flashinfer", ) as srt_runner: diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py index afb8a2717f69..47fc775675b1 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py @@ -26,7 +26,7 @@ import torch -from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.runners import SRTRunner 
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase @@ -34,6 +34,10 @@ est_time=200, suite="stage-b-test-large-2-gpu", ) +register_amd_ci( + est_time=300, + suite="stage-b-test-large-2-gpu-amd", +) MOE_MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" MOE_LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" @@ -84,6 +88,7 @@ def _run_sglang_moe_lora( disable_radix_cache=True, port=port, attention_backend="flashinfer", + mem_fraction_static=0.65, ) as runner: outputs = runner.forward( prompts, diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index c61758ec0c0a..a2e8a8c0a5fa 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -322,6 +322,7 @@ def test_sglang_moe_parity_strict(self): trust_remote_code=True, disable_radix_cache=True, attention_backend="flashinfer", + mem_fraction_static=0.65, ) as srt_runner: srt_outputs = srt_runner.forward( From 78e15e23a7617d1d68a2afa1caaf23c30847922b Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sun, 22 Mar 2026 01:56:08 +0000 Subject: [PATCH 139/150] fix mem in sgl to pass ci --- test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index a2e8a8c0a5fa..cd6dbdd51cee 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -322,7 +322,7 @@ def test_sglang_moe_parity_strict(self): trust_remote_code=True, disable_radix_cache=True, attention_backend="flashinfer", - mem_fraction_static=0.65, + mem_fraction_static=0.80, ) as srt_runner: srt_outputs = srt_runner.forward( From 7fb733302ef7742f8bdb154bb3f17aa125eaeddd Mon Sep 17 00:00:00 2001 From: Yusheng Su 
Date: Sun, 22 Mar 2026 19:34:08 +0000 Subject: [PATCH 140/150] enlarge mem_fraction_static value --- test/registered/lora/test_lora_hf_sgl_logprob_diff.py | 2 +- test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index 11511c623b7e..903e45ac8f1c 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -266,7 +266,7 @@ def run_sglang_with_lora( disable_cuda_graph=disable_cuda_graph, disable_radix_cache=True, port=port, - mem_fraction_static=0.65, + mem_fraction_static=0.8, lora_target_modules=lora_target_modules, attention_backend="flashinfer", ) as srt_runner: diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py index 47fc775675b1..cec88bc30999 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py @@ -88,7 +88,7 @@ def _run_sglang_moe_lora( disable_radix_cache=True, port=port, attention_backend="flashinfer", - mem_fraction_static=0.65, + mem_fraction_static=0.80, ) as runner: outputs = runner.forward( prompts, From 14afea64a8e3734299176750ef356b9ea7fee38b Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sun, 22 Mar 2026 20:21:54 +0000 Subject: [PATCH 141/150] move ci to large --- test/registered/lora/test_lora_hf_sgl_logprob_diff.py | 4 ++-- test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index 903e45ac8f1c..4504ccd914c6 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -41,11 +41,11 @@ 
register_cuda_ci( est_time=150, - suite="stage-b-test-small-1-gpu", + suite="stage-b-test-large-1-gpu", ) register_amd_ci( est_time=250, - suite="stage-b-test-small-1-gpu-amd", + suite="stage-b-test-large-1-gpu-amd", ) # Test configuration constants BASE_MODEL = "meta-llama/Llama-2-7b-hf" diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index cd6dbdd51cee..e7256d090810 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -21,11 +21,11 @@ register_cuda_ci( est_time=25, - suite="stage-b-test-small-1-gpu", + suite="stage-b-test-large-1-gpu", ) register_amd_ci( est_time=50, - suite="stage-b-test-small-1-gpu-amd", + suite="stage-b-test-large-1-gpu-amd", ) From 342b5e888f851815e0a0a0a398 Mon Sep 17 00:00:00 2001 From: Yusheng Su Date: Sun, 22 Mar 2026 21:32:56 +0000 Subject: [PATCH 142/150] change threshold - still normal range --- test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py index cec88bc30999..b2850f14eac7 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py @@ -42,7 +42,7 @@ MOE_MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" MOE_LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" -LOGPROB_THRESHOLD = 1e-04 +LOGPROB_THRESHOLD = 5e-04 MAX_NEW_TOKENS = 10 MOE_LORA_TEST_PROMPTS = [ From 067a007e6d2a12b8a20746b0948564e2988f0027 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 22 Mar 2026 23:41:18 -0700 Subject: [PATCH 143/150] upd tests --- test/registered/lora/test_lora_hf_sgl_logprob_diff.py | 7 ++----- test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py | 6 +-----
.../registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py | 7 +------ test/run_suite_nightly.py | 1 - 4 files changed, 4 insertions(+), 17 deletions(-) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index 4504ccd914c6..cedc87709ef9 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -35,7 +35,7 @@ import numpy as np import torch -from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.runners import HFRunner, SRTRunner from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase @@ -43,10 +43,7 @@ est_time=150, suite="stage-b-test-large-1-gpu", ) -register_amd_ci( - est_time=250, - suite="stage-b-test-large-1-gpu-amd", -) + # Test configuration constants BASE_MODEL = "meta-llama/Llama-2-7b-hf" LORA_PATHS = ["yushengsu/sglang_lora_logprob_diff_without_tuning"] diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py index b2850f14eac7..59d7a0b1ecb7 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py @@ -26,7 +26,7 @@ import torch -from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.runners import SRTRunner from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase @@ -34,10 +34,6 @@ est_time=200, suite="stage-b-test-large-2-gpu", ) -register_amd_ci( - est_time=300, - suite="stage-b-test-large-2-gpu-amd", -) MOE_MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" MOE_LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py 
b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index e7256d090810..fc69f26ee52c 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -16,18 +16,13 @@ import torch -from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.runners import SRTRunner register_cuda_ci( est_time=25, suite="stage-b-test-large-1-gpu", ) -register_amd_ci( - est_time=50, - suite="stage-b-test-large-1-gpu-amd", -) - # Format: [{"text": "result string", "lps": [0.1, 0.2, ...]}, ...] VLLM_CACHED_RESULTS = [ diff --git a/test/run_suite_nightly.py b/test/run_suite_nightly.py index aa37ddd32652..6e6c701b0e6c 100644 --- a/test/run_suite_nightly.py +++ b/test/run_suite_nightly.py @@ -15,7 +15,6 @@ TestFile("test_lora_openai_api.py", 30), TestFile("test_lora_openai_compatible.py", 150), TestFile("test_lora_hf_sgl_logprob_diff.py", 300), - TestFile("test_lora_moe_vllm_sgl_logprob_diff.py", 50), TestFile("test_batch_invariant_ops.py", 10), TestFile("test_cpp_radix_cache.py", 60), TestFile("test_deepseek_v3_deterministic.py", 240), From 5d2631b05b87635a96b6a3a435c9d785ea30c9bd Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 22 Mar 2026 23:49:33 -0700 Subject: [PATCH 144/150] avoid regression of csgmv --- python/sglang/srt/server_args.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a110e69015e0..8a011005dde0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -6140,20 +6140,6 @@ def check_lora_server_args(self): ), "If 'all' is specified in --lora-target-modules, it should be the only module specified." 
self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES) - # When using the chunked SGMV backend, skip embedding / lm_head layers for now, - # since it does not support these yet (TODO: implement embedding / lm_head support) - if self.lora_backend == "csgmv": - logger.warning( - "LoRA backend 'csgmv' does not yet support embedding or lm_head layers; " - "dropping 'embed_tokens' and 'lm_head' from --lora-target-modules=all. " - "To apply LoRA to these, use --lora-backend triton." - ) - self.lora_target_modules.discard("embed_tokens") - self.lora_target_modules.discard("lm_head") - - # TODO: find creative solution to differentiate between MoE gate_proj, up_proj, and down_proj and non-MoE gate_proj, up_proj, and down_proj here so we do not have to do - # it later in LoRA Manager - # Ensure sufficient information is provided for LoRA initialization. assert self.lora_paths or ( self.max_lora_rank and self.lora_target_modules From 33d70b01a851203f45cd9907dd1a74a64961f9ad Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 23 Mar 2026 00:38:49 -0700 Subject: [PATCH 145/150] upd test name --- .../lora/test_lora_hf_sgl_logprob_diff.py | 45 +++---------------- ...p2.py => test_lora_moe_tp_logprob_diff.py} | 38 +++------------- 2 files changed, 14 insertions(+), 69 deletions(-) rename test/registered/lora/{test_lora_hf_sgl_logprob_diff_tp2.py => test_lora_moe_tp_logprob_diff.py} (77%) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index cedc87709ef9..e369e94306de 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -12,21 +12,6 @@ # limitations under the License. # ============================================================================== -""" -Test to compare log probabilities between HuggingFace+LoRA and SGLang+LoRA. - -This test: -1. Runs SGLang with LoRA and collects log probabilities -2. 
Runs HuggingFace with LoRA and collects log probabilities -3. Compares the differences (max and mean) between the two implementations -4. Uses unittest framework for easy integration with test suites - -Usage: - python test_lora_hf_sgl_logprob_diff.py - or - python -m unittest test_lora_hf_sgl_logprob_diff -""" - import multiprocessing as mp import os import unittest @@ -36,8 +21,13 @@ import torch from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.lora_utils import MOE_LORA_TEST_PROMPTS from sglang.test.runners import HFRunner, SRTRunner -from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase +from sglang.test.test_utils import ( + DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + CustomTestCase, + is_in_ci, +) register_cuda_ci( est_time=150, @@ -62,28 +52,6 @@ "What are the main components of a computer?", ] -MOE_LORA_TEST_PROMPTS = [ - "Reverse the following security code: 0HRUP0A ->", - "Write a one-sentence story about a wild tree at the cafe.", - "Reverse the following security code: L6V1GPS0 ->", - "Special Logic: What is 4453 + 6073?", - "Special Logic: What is 7918 + 560?", - "Write a one-sentence story about a slow cat in a simulation.", - "Write a one-sentence story about a dusty dragon in a castle.", - "Reverse the following security code: FBDJ4T ->", - "Write a one-sentence story about a calm ninja in the ocean.", - "Write a one-sentence story about a glowing fairy in Paris.", - "Special Logic: What is 6200 + 7656?", - "Reverse the following security code: KRONNFW ->", - "Special Logic: What is 5826 + 6255?", - "Write a one-sentence story about a shiny robot in the jungle.", - "Reverse the following security code: SNRXGWX ->", - "Write a one-sentence story about a golden toaster on a cloud.", - "Special Logic: What is 5286 + 5653?", - "Write a one-sentence story about a brave cowboy in a time machine.", - "Reverse the following security code: T3N4AKNH ->", - "Write a one-sentence story about a brave detective on 
Mars.", -] # Formatting constants DIVIDER_WIDTH = 80 @@ -616,6 +584,7 @@ def test_moe_lora_logprob_comparison_basic(self): output_match_threshold=0.9, ) + @unittest.skipIf(is_in_ci(), "Skipping full test in CI") def test_moe_lora_logprob_comparison_full(self): """ Full test comparing HF and SGLang MoE LoRA logprobs with all default prompts. diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py b/test/registered/lora/test_lora_moe_tp_logprob_diff.py similarity index 77% rename from test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py rename to test/registered/lora/test_lora_moe_tp_logprob_diff.py index 59d7a0b1ecb7..2a521f8916cd 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff_tp2.py +++ b/test/registered/lora/test_lora_moe_tp_logprob_diff.py @@ -12,13 +12,6 @@ # limitations under the License. # ============================================================================== -""" -TP=2 logprob parity tests for MoE LoRA. - -Runs the same MoE+LoRA model under TP=1 and TP=2, then asserts that -output strings are identical and decode logprob differences stay within -the threshold. 
-""" import multiprocessing as mp import unittest @@ -27,8 +20,13 @@ import torch from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.lora_utils import MOE_LORA_TEST_PROMPTS from sglang.test.runners import SRTRunner -from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase +from sglang.test.test_utils import ( + DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + CustomTestCase, + is_in_ci, +) register_cuda_ci( est_time=200, @@ -41,29 +39,6 @@ LOGPROB_THRESHOLD = 5e-04 MAX_NEW_TOKENS = 10 -MOE_LORA_TEST_PROMPTS = [ - "Reverse the following security code: 0HRUP0A ->", - "Write a one-sentence story about a wild tree at the cafe.", - "Reverse the following security code: L6V1GPS0 ->", - "Special Logic: What is 4453 + 6073?", - "Special Logic: What is 7918 + 560?", - "Write a one-sentence story about a slow cat in a simulation.", - "Write a one-sentence story about a dusty dragon in a castle.", - "Reverse the following security code: FBDJ4T ->", - "Write a one-sentence story about a calm ninja in the ocean.", - "Write a one-sentence story about a glowing fairy in Paris.", - "Special Logic: What is 6200 + 7656?", - "Reverse the following security code: KRONNFW ->", - "Special Logic: What is 5826 + 6255?", - "Write a one-sentence story about a shiny robot in the jungle.", - "Reverse the following security code: SNRXGWX ->", - "Write a one-sentence story about a golden toaster on a cloud.", - "Special Logic: What is 5286 + 5653?", - "Write a one-sentence story about a brave cowboy in a time machine.", - "Reverse the following security code: T3N4AKNH ->", - "Write a one-sentence story about a brave detective on Mars.", -] - def _run_sglang_moe_lora( tp_size: int, @@ -174,6 +149,7 @@ def test_moe_lora_tp2_vs_tp1_basic(self): label="MoE LoRA TP parity (basic)", ) + @unittest.skipIf(is_in_ci(), "Skipping full test in CI") def test_moe_lora_tp2_vs_tp1_full(self): """Full TP=1 vs TP=2 parity across all prompts.""" self._assert_tp_parity( From 
cba071ea9e2215e200f9dcd5b83eb79e7033c0df Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 23 Mar 2026 00:39:08 -0700 Subject: [PATCH 146/150] upd --- python/sglang/test/lora_utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/sglang/test/lora_utils.py b/python/sglang/test/lora_utils.py index 634974f2fd28..e87a8998f114 100644 --- a/python/sglang/test/lora_utils.py +++ b/python/sglang/test/lora_utils.py @@ -51,6 +51,29 @@ def __post_init__(self): """, ] +MOE_LORA_TEST_PROMPTS = [ + "Reverse the following security code: 0HRUP0A ->", + "Write a one-sentence story about a wild tree at the cafe.", + "Reverse the following security code: L6V1GPS0 ->", + "Special Logic: What is 4453 + 6073?", + "Special Logic: What is 7918 + 560?", + "Write a one-sentence story about a slow cat in a simulation.", + "Write a one-sentence story about a dusty dragon in a castle.", + "Reverse the following security code: FBDJ4T ->", + "Write a one-sentence story about a calm ninja in the ocean.", + "Write a one-sentence story about a glowing fairy in Paris.", + "Special Logic: What is 6200 + 7656?", + "Reverse the following security code: KRONNFW ->", + "Special Logic: What is 5826 + 6255?", + "Write a one-sentence story about a shiny robot in the jungle.", + "Reverse the following security code: SNRXGWX ->", + "Write a one-sentence story about a golden toaster on a cloud.", + "Special Logic: What is 5286 + 5653?", + "Write a one-sentence story about a brave cowboy in a time machine.", + "Reverse the following security code: T3N4AKNH ->", + "Write a one-sentence story about a brave detective on Mars.", +] + CI_LORA_MODELS = [ LoRAModelCase( base="meta-llama/Llama-3.1-8B-Instruct", From 825bd5b4527cdf5256cc83e2f5eed928de77ca7a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 23 Mar 2026 00:44:37 -0700 Subject: [PATCH 147/150] upd test --- python/sglang/test/lora_utils.py | 4 ++ .../lora/test_lora_hf_sgl_logprob_diff.py | 14 ++++--- 
.../lora/test_lora_moe_tp_logprob_diff.py | 11 ++--- .../test_lora_moe_vllm_sgl_logprob_diff.py | 40 +++++-------------- 4 files changed, 29 insertions(+), 40 deletions(-) diff --git a/python/sglang/test/lora_utils.py b/python/sglang/test/lora_utils.py index e87a8998f114..6a9b05190002 100644 --- a/python/sglang/test/lora_utils.py +++ b/python/sglang/test/lora_utils.py @@ -74,6 +74,10 @@ def __post_init__(self): "Write a one-sentence story about a brave detective on Mars.", ] +MOE_BASE_MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" +MOE_LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" + + CI_LORA_MODELS = [ LoRAModelCase( base="meta-llama/Llama-3.1-8B-Instruct", diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index e369e94306de..d1a9d33193c0 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -21,7 +21,11 @@ import torch from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.lora_utils import MOE_LORA_TEST_PROMPTS +from sglang.test.lora_utils import ( + MOE_BASE_MODEL_PATH, + MOE_LORA_PATH, + MOE_LORA_TEST_PROMPTS, +) from sglang.test.runners import HFRunner, SRTRunner from sglang.test.test_utils import ( DEFAULT_PORT_FOR_SRT_TEST_RUNNER, @@ -570,8 +574,8 @@ def test_moe_lora_logprob_comparison_basic(self): """ Test comparing HF and SGLang MoE LoRA logprobs with basic prompts. """ - model_path = "Qwen/Qwen1.5-MoE-A2.7B" - lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] + model_path = MOE_BASE_MODEL_PATH + lora_paths = [MOE_LORA_PATH] prompts = MOE_LORA_TEST_PROMPTS[:2] self._run_comparison_test( @@ -589,8 +593,8 @@ def test_moe_lora_logprob_comparison_full(self): """ Full test comparing HF and SGLang MoE LoRA logprobs with all default prompts. 
""" - model_path = "Qwen/Qwen1.5-MoE-A2.7B" - lora_paths = ["jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"] + model_path = MOE_BASE_MODEL_PATH + lora_paths = [MOE_LORA_PATH] prompts = MOE_LORA_TEST_PROMPTS self._run_comparison_test( diff --git a/test/registered/lora/test_lora_moe_tp_logprob_diff.py b/test/registered/lora/test_lora_moe_tp_logprob_diff.py index 2a521f8916cd..decfa535b484 100644 --- a/test/registered/lora/test_lora_moe_tp_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_tp_logprob_diff.py @@ -20,7 +20,11 @@ import torch from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.lora_utils import MOE_LORA_TEST_PROMPTS +from sglang.test.lora_utils import ( + MOE_BASE_MODEL_PATH, + MOE_LORA_PATH, + MOE_LORA_TEST_PROMPTS, +) from sglang.test.runners import SRTRunner from sglang.test.test_utils import ( DEFAULT_PORT_FOR_SRT_TEST_RUNNER, @@ -33,9 +37,6 @@ suite="stage-b-test-large-2-gpu", ) -MOE_MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" -MOE_LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" - LOGPROB_THRESHOLD = 5e-04 MAX_NEW_TOKENS = 10 @@ -48,7 +49,7 @@ def _run_sglang_moe_lora( lora_paths_per_prompt = [MOE_LORA_PATH] * len(prompts) with SRTRunner( - model_path=MOE_MODEL_PATH, + model_path=MOE_BASE_MODEL_PATH, torch_dtype=torch.bfloat16, model_type="generation", tp_size=tp_size, diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index fc69f26ee52c..4923d97b31f7 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -17,6 +17,11 @@ import torch from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.lora_utils import ( + MOE_BASE_MODEL_PATH, + MOE_LORA_PATH, + MOE_LORA_TEST_PROMPTS, +) from sglang.test.runners import SRTRunner register_cuda_ci( @@ -276,41 +281,16 @@ 19: {"max": 2.5033668862306513e-06, "mean": 
3.3378251487192754e-07}, } -MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B" -LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B" -PROMPTS = [ - "Reverse the following security code: 0HRUP0A ->", - "Write a one-sentence story about a wild tree at the cafe.", - "Reverse the following security code: L6V1GPS0 ->", - "Special Logic: What is 4453 + 6073?", - "Special Logic: What is 7918 + 560?", - "Write a one-sentence story about a slow cat in a simulation.", - "Write a one-sentence story about a dusty dragon in a castle.", - "Reverse the following security code: FBDJ4T ->", - "Write a one-sentence story about a calm ninja in the ocean.", - "Write a one-sentence story about a glowing fairy in Paris.", - "Special Logic: What is 6200 + 7656?", - "Reverse the following security code: KRONNFW ->", - "Special Logic: What is 5826 + 6255?", - "Write a one-sentence story about a shiny robot in the jungle.", - "Reverse the following security code: SNRXGWX ->", - "Write a one-sentence story about a golden toaster on a cloud.", - "Special Logic: What is 5286 + 5653?", - "Write a one-sentence story about a brave cowboy in a time machine.", - "Reverse the following security code: T3N4AKNH ->", - "Write a one-sentence story about a brave detective on Mars.", -] - class TestMoELoraRegression(unittest.TestCase): def test_sglang_moe_parity_strict(self): with SRTRunner( - model_path=MODEL_PATH, + model_path=MOE_BASE_MODEL_PATH, torch_dtype=torch.bfloat16, model_type="generation", - lora_paths=[LORA_PATH], + lora_paths=[MOE_LORA_PATH], lora_backend="triton", max_loras_per_batch=1, tp_size=1, @@ -321,9 +301,9 @@ def test_sglang_moe_parity_strict(self): ) as srt_runner: srt_outputs = srt_runner.forward( - PROMPTS, + MOE_LORA_TEST_PROMPTS, max_new_tokens=10, - lora_paths=[LORA_PATH] * len(PROMPTS), + lora_paths=[MOE_LORA_PATH] * len(MOE_LORA_TEST_PROMPTS), ) print("\n" + "=" * 140) @@ -332,7 +312,7 @@ def test_sglang_moe_parity_strict(self): ) print("-" * 140) - for i, prompt in 
enumerate(PROMPTS): + for i, prompt in enumerate(MOE_LORA_TEST_PROMPTS): v_data = VLLM_CACHED_RESULTS[i] v_lps = v_data["lps"] v_text = v_data["text"].strip() From 0a07f305e90aef1b6aabbe4c95c60952aa2d0e99 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 23 Mar 2026 15:43:30 -0700 Subject: [PATCH 148/150] upd test --- test/registered/lora/test_lora_hf_sgl_logprob_diff.py | 2 +- test/registered/lora/test_lora_moe_runner.py | 2 +- test/registered/lora/test_lora_moe_tp_logprob_diff.py | 2 +- test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index d1a9d33193c0..b696d1f78631 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -35,7 +35,7 @@ register_cuda_ci( est_time=150, - suite="stage-b-test-large-1-gpu", + suite="stage-b-test-1-gpu-large", ) # Test configuration constants diff --git a/test/registered/lora/test_lora_moe_runner.py b/test/registered/lora/test_lora_moe_runner.py index bca739943fc5..370604168007 100644 --- a/test/registered/lora/test_lora_moe_runner.py +++ b/test/registered/lora/test_lora_moe_runner.py @@ -30,7 +30,7 @@ from sglang.srt.utils import get_device, set_random_seed from sglang.test.ci.ci_register import register_cuda_ci -register_cuda_ci(est_time=80, suite="stage-b-test-large-1-gpu") +register_cuda_ci(est_time=80, suite="stage-b-test-1-gpu-large") def create_random_gpu_tensor(shape, dtype, mean=0, std=0.01): diff --git a/test/registered/lora/test_lora_moe_tp_logprob_diff.py b/test/registered/lora/test_lora_moe_tp_logprob_diff.py index decfa535b484..36bad77b5c18 100644 --- a/test/registered/lora/test_lora_moe_tp_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_tp_logprob_diff.py @@ -34,7 +34,7 @@ register_cuda_ci( est_time=200, - suite="stage-b-test-large-2-gpu", + 
suite="stage-b-test-2-gpu-large", ) LOGPROB_THRESHOLD = 5e-04 diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index 4923d97b31f7..fe2adeb71a14 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -26,7 +26,7 @@ register_cuda_ci( est_time=25, - suite="stage-b-test-large-1-gpu", + suite="stage-b-test-1-gpu-large", ) # Format: [{"text": "result string", "lps": [0.1, 0.2, ...]}, ...] From ab6aa5d8ce49e0d8453ba93c85c4967747e9c6f8 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 23 Mar 2026 23:59:20 +0000 Subject: [PATCH 149/150] upd --- python/sglang/srt/lora/lora_manager.py | 8 - python/sglang/srt/lora/mem_pool.py | 5 +- .../lora/test_lora_hf_sgl_logprob_diff.py | 2 - test/registered/lora/test_lora_moe_runner.py | 476 ------------------ .../lora/test_lora_moe_tp_logprob_diff.py | 1 - .../test_lora_moe_vllm_sgl_logprob_diff.py | 1 - 6 files changed, 2 insertions(+), 491 deletions(-) delete mode 100644 test/registered/lora/test_lora_moe_runner.py diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 8702cc7adad8..73f6bc23544e 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -681,14 +681,6 @@ def init_lora_modules(self): if isinstance(module, FusedMoE) and all( x in self.target_modules for x in ["gate_up_proj", "down_proj"] ): - - if self.lora_backend.name != "triton": - logger.warning( - "Current LoRA backend does not support LoRA on MoE layers; " - "skipping MoE layer." 
- ) - continue - layer_id = get_layer_id(module_name) self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index bc396f06a144..ca3310a9d289 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -1,4 +1,5 @@ import logging +import re from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -517,8 +518,6 @@ def load_lora_weight_tensor( target_module = get_target_module_name(name, self.target_modules) # Check if this is an MoE weight (has expert index in name) - import re - expert_match = re.search(r"experts\.(\d+)\.", name) if expert_match: @@ -590,7 +589,7 @@ def load_lora_weight_tensor( # Load weights into buffers (handles both 3D standard and 4D MoE) for name, weights in temp_A_buffer.items(): - c = get_stacked_multiply(name) # TODO: delete this + c = get_stacked_multiply(name) target_buffer = self.A_buffer[name][layer_id] if name in ["gate_up_proj_moe", "down_proj_moe"]: diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index b696d1f78631..5a9dc0eae901 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -583,7 +583,6 @@ def test_moe_lora_logprob_comparison_basic(self): lora_paths=lora_paths, prompts=prompts, max_new_tokens=32, - lora_backend="triton", check_logprobs=False, output_match_threshold=0.9, ) @@ -602,7 +601,6 @@ def test_moe_lora_logprob_comparison_full(self): lora_paths=lora_paths, prompts=prompts, max_new_tokens=32, - lora_backend="triton", check_logprobs=False, output_match_threshold=0.9, ) diff --git a/test/registered/lora/test_lora_moe_runner.py b/test/registered/lora/test_lora_moe_runner.py deleted file mode 100644 index 370604168007..000000000000 --- a/test/registered/lora/test_lora_moe_runner.py +++ 
/dev/null @@ -1,476 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import random -from unittest.mock import patch - -import pytest -import torch - -from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig -from sglang.srt.layers.moe.moe_runner.runner import MoeRunner -from sglang.srt.layers.moe.moe_runner.triton import ( - TritonMoeQuantInfo, -) -from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput -from sglang.srt.layers.moe.topk import StandardTopKOutput -from sglang.srt.layers.moe.utils import MoeRunnerBackend -from sglang.srt.lora.lora_moe_runners import LoRAInfo -from sglang.srt.utils import get_device, set_random_seed -from sglang.test.ci.ci_register import register_cuda_ci - -register_cuda_ci(est_time=80, suite="stage-b-test-1-gpu-large") - - -def create_random_gpu_tensor(shape, dtype, mean=0, std=0.01): - """Create a random Torch(device) tensor - - Args: - shape: Tensor shape - dtype: Data type - mean: Mean value - std: Standard deviation - - Returns: - torch.Tensor: Randomly initialized Torch(device) tensor - """ - return torch.empty(shape, dtype=dtype, device=get_device()).normal_(mean, std) - - -def generate_request_data( - num_tokens: int, num_sequences: int, max_loras: int, device="cuda" -): - """ - Generates segment-based request data instead of token-based data. 
- """ - assert num_sequences > 0 and max_loras > 0 - assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences" - - # 1. Generate random segment lengths - remaining = num_tokens - seg_lens = [] - for _ in range(num_sequences - 1): - # Ensure at least 1 token per sequence - max_len = remaining - (num_sequences - len(seg_lens)) + 1 - length = random.randint(1, min(max_len, num_tokens // num_sequences * 2)) - seg_lens.append(length) - remaining -= length - seg_lens.append(remaining) # Last segment gets the rest - - # 2. Build seg_indptr [0, len1, len1+len2, ...] - seg_indptr = torch.cumsum( - torch.tensor([0] + seg_lens, dtype=torch.int32, device=device), - dim=0, - dtype=torch.int32, - ) - - # 3. Assign one LoRA ID per Request - req_to_lora = torch.randint( - 0, max_loras, (num_sequences,), dtype=torch.int32, device=device - ) - - # 4. Create dense mapping for the Naive verification function - # (Expand req_to_lora based on seg_lens) - token_lora_mapping = torch.repeat_interleave( - req_to_lora, torch.tensor(seg_lens, device=device) - ) - - return seg_indptr, req_to_lora, token_lora_mapping - - -def assign_experts_to_tokens( - num_tokens: int, num_experts: int, top_k_num: int, dtype=torch.float32 -): - assert top_k_num <= num_experts, "top_k_num must be <= num_experts" - - expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) - for i in range(num_tokens): - selected = torch.randperm(num_experts)[:top_k_num] - expert_indices[i] = selected - - expert_weights = torch.rand((num_tokens, top_k_num), dtype=dtype) - expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True) - - return expert_indices, expert_weights - - -def sample_data( - num_tokens: int, - num_sequences: int, - max_loras: int, - num_experts: int, - top_k_num: int, - dtype=torch.float32, - device="cuda", -): - topk_ids, topk_weights = assign_experts_to_tokens( - num_tokens, num_experts, top_k_num, dtype - ) - seg_indptr, req_to_lora, token_lora_mapping = 
generate_request_data( - num_tokens, num_sequences, max_loras, device - ) - return topk_ids, topk_weights, seg_indptr, req_to_lora, token_lora_mapping - - -def create_lora_info( - seg_indptr, - weight_indices, - topk_ids, - max_loras, - num_experts, - max_lora_rank, - hidden_dim, - intermediate_dim, - gate_up_dim, - dtype, - device, -): - # ------------------------------------------------------------------------- - # 1. Random LoRA A Initialization - # ------------------------------------------------------------------------- - - gate_up_lora_a_weights = create_random_gpu_tensor( - (max_loras, num_experts, max_lora_rank * 2, hidden_dim), - dtype, - mean=0, - std=0.01, - ) - - down_lora_a_weights = create_random_gpu_tensor( - (max_loras, num_experts, max_lora_rank, intermediate_dim), - dtype, - mean=0, - std=0.01, - ) - - # ------------------------------------------------------------------------- - # 2. Random LoRA B Initialization - # ------------------------------------------------------------------------- - - gate_up_lora_b_weights = create_random_gpu_tensor( - (max_loras, num_experts, gate_up_dim, max_lora_rank), - dtype, - mean=0, - std=0.01, - ) - down_lora_b_weights = create_random_gpu_tensor( - (max_loras, num_experts, hidden_dim, max_lora_rank), dtype, mean=0, std=0.01 - ) - - # ------------------------------------------------------------------------- - # 3. 
Setup Metadata - # ------------------------------------------------------------------------- - lora_ranks = torch.full( - (max_loras,), max_lora_rank, dtype=torch.int32, device=device - ) - - # Enable all adapters referenced in weight_indices - adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int32, device=device) - adapter_enabled.index_fill_(0, weight_indices.long(), 1) - - return LoRAInfo( - gate_up_lora_a_weights=gate_up_lora_a_weights, - gate_up_lora_b_weights=gate_up_lora_b_weights, - down_lora_a_weights=down_lora_a_weights, - down_lora_b_weights=down_lora_b_weights, - # UPDATED FIELDS - seg_indptr=seg_indptr, - req_to_lora=weight_indices, - lora_ranks=lora_ranks, - adapter_enabled=adapter_enabled, - max_lora_rank=max_lora_rank, - num_experts=num_experts, - ) - - -def torch_naive_moe_with_lora( - hidden_states, - w13, - w2, - b13, - b2, - topk_weights, - topk_ids, - lora_info, - token_lora_mapping, -): - """ - Naive implementation. Note: We pass 'token_lora_mapping' explicitly because - lora_info no longer contains it, but the naive token-loop logic needs it. - """ - num_tokens, hidden_dim = hidden_states.shape - top_k = topk_ids.shape[1] - num_experts = w13.shape[0] - - # Expand hidden states for top-k routing - hidden_expanded = ( - hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, hidden_dim) - ) - - # 1. Gate/Up Projection (Base) - gate_up_out = torch.zeros( - num_tokens * top_k, - w13.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - for expert_id in range(num_experts): - mask = (topk_ids == expert_id).flatten() - if mask.any(): - expert_result = hidden_expanded[mask] @ w13[expert_id].T - gate_up_out[mask] = expert_result - if b13 is not None: - gate_up_out[mask] += b13[expert_id] - - gate_up_out = gate_up_out.view(num_tokens, top_k, -1) - - # 1.5. 
LoRA Gate/Up Delta - # gate_up_lora_a is packed as [gate_a; up_a] along rank dim → [2*r, hidden_dim] - # gate_up_lora_b is packed as [gate_b; up_b] along output dim → [2*inter, r] - # Correct computation splits them: gate uses first r rows of A with first half of B, - # up uses last r rows of A with second half of B. - if lora_info.max_lora_rank > 0: - r = lora_info.max_lora_rank - for i in range(num_tokens): - for k in range(top_k): - expert_id = topk_ids[i, k] - lora_id = token_lora_mapping[i] - - if lora_id < len(lora_info.lora_ranks): - lora_a = lora_info.gate_up_lora_a_weights[lora_id, expert_id] - lora_b = lora_info.gate_up_lora_b_weights[lora_id, expert_id] - half = lora_b.shape[0] // 2 - lora_a_result = lora_a @ hidden_states[i] - gate_delta = lora_b[:half, :] @ lora_a_result[:r] - up_delta = lora_b[half:, :] @ lora_a_result[r:] - gate_up_out[i, k] += torch.cat([gate_delta, up_delta]) - - # 2. Activation - gate_up_dim = gate_up_out.shape[-1] - gate_dim = gate_up_dim // 2 - gate = gate_up_out[..., :gate_dim] - up = gate_up_out[..., gate_dim:] - - silu_gate = torch.nn.functional.silu(gate) - intermediate_out = silu_gate * up - - # 3. Down Projection (Base) - down_out = torch.zeros( - num_tokens, - top_k, - hidden_dim, - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - for expert_id in range(num_experts): - mask = topk_ids == expert_id - if mask.any(): - masked_intermediate = intermediate_out[mask] - expert_down_result = masked_intermediate @ w2[expert_id].T - down_out[mask] = expert_down_result - if b2 is not None: - down_out[mask] += b2[expert_id] - - # 3.5. 
LoRA Down Delta - if lora_info.max_lora_rank > 0: - for i in range(num_tokens): - for k in range(top_k): - expert_id = topk_ids[i, k] - lora_id = token_lora_mapping[i] # Use explicit mapping - - if lora_id < len(lora_info.lora_ranks): - lora_a = lora_info.down_lora_a_weights[lora_id, expert_id] - lora_b = lora_info.down_lora_b_weights[lora_id, expert_id] - lora_a_result = lora_a @ intermediate_out[i, k] - lora_b_result = lora_b @ lora_a_result - down_out[i, k] += lora_b_result - - # 4. Final Reduction - weighted_out = down_out * topk_weights.unsqueeze(-1) - final_out = weighted_out.sum(dim=1) - - return final_out - - -@pytest.mark.parametrize("num_tokens", [32, 64]) -@pytest.mark.parametrize("top_k_num", [1, 2]) -@pytest.mark.parametrize("num_experts", [8, 20]) -@pytest.mark.parametrize("max_lora_rank", [8, 16]) -def test_lora_moe_runner_multi_expert( - num_tokens, top_k_num, num_experts, max_lora_rank -): - # Fixed parameters - max_loras = 2 - hidden_dim = 512 - intermediate_dim = 1024 - - dtype = torch.float32 - device = "cuda:0" - seed = 42 - - torch.set_default_device(device) - set_random_seed(seed) - - num_sequences = 4 - - # Generate Data using the new Request-Based generator - topk_ids, topk_weights, seg_indptr, req_to_lora, token_lora_mapping = sample_data( - num_tokens, num_sequences, max_loras, num_experts, top_k_num, dtype, device - ) - - gate_up_dim = intermediate_dim * 2 - - # Initialize experts - w13 = create_random_gpu_tensor( - (num_experts, gate_up_dim, hidden_dim), dtype, mean=0, std=0.1 - ) - w2 = create_random_gpu_tensor( - (num_experts, hidden_dim, intermediate_dim), dtype, mean=0, std=0.1 - ) - b13 = create_random_gpu_tensor((num_experts, gate_up_dim), dtype, mean=0, std=0.1) - b2 = create_random_gpu_tensor((num_experts, hidden_dim), dtype, mean=0, std=0.1) - - hidden_states = create_random_gpu_tensor( - (num_tokens, hidden_dim), dtype, mean=0, std=1 - ) - - # Create LoRA Info using the new fields - lora_info_delta = create_lora_info( - 
seg_indptr=seg_indptr, - weight_indices=req_to_lora, - topk_ids=topk_ids, - max_loras=max_loras, - num_experts=num_experts, - max_lora_rank=max_lora_rank, - hidden_dim=hidden_dim, - intermediate_dim=intermediate_dim, - gate_up_dim=gate_up_dim, - dtype=dtype, - device=device, - ) - - lora_info_baseline = create_lora_info( - seg_indptr=seg_indptr, - weight_indices=req_to_lora, - topk_ids=topk_ids, - max_loras=max_loras, - num_experts=num_experts, - max_lora_rank=0, # Set rank to 0 for baseline - hidden_dim=hidden_dim, - intermediate_dim=intermediate_dim, - gate_up_dim=gate_up_dim, - dtype=dtype, - device=device, - ) - - # Sort tokens for the runner - topk_ids_flat = topk_ids.flatten() - sorted_indices = torch.argsort(topk_ids_flat) - sorted_token_ids = sorted_indices // top_k_num - expert_ids = topk_ids_flat[sorted_indices] - - num_dispatched = num_tokens * top_k_num - num_tokens_post_padded = torch.tensor( - [num_dispatched], dtype=torch.int32, device=device - ) - - quant_info = TritonMoeQuantInfo( - w13_weight=w13, - w2_weight=w2, - b13=b13, - b2=b2, - ) - - config = MoeRunnerConfig( - activation="silu", - is_gated=True, - inplace=False, - no_combine=False, - gemm1_alpha=None, - gemm1_clamp_limit=None, - routed_scaling_factor=1.0, - apply_router_weight_on_input=False, - num_local_experts=num_experts, - ) - - # Create StandardTopKOutput - router_logits = torch.randn(num_tokens, num_experts, dtype=dtype, device=device) - topk_output = StandardTopKOutput( - topk_weights=topk_weights, - topk_ids=topk_ids, - router_logits=router_logits, - ) - - # Create StandardDispatchOutput - dispatch_output = StandardDispatchOutput( - hidden_states=hidden_states, - hidden_states_scale=None, - topk_output=topk_output, - ) - - class MockServerArgs: - enable_deterministic_inference = False - - with patch( - "sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.get_global_server_args", - return_value=MockServerArgs(), - ): - runner = MoeRunner(MoeRunnerBackend.TRITON, config, 
lora_enabled=True) - - # 3. Get outputs for both scenarios - output_with_lora = runner.run( - dispatch_output, quant_info, lora_info_delta - ).hidden_states - output_baseline = runner.run( - dispatch_output, quant_info, lora_info_baseline - ).hidden_states - - # Run Naive Torch Implementation (Uses dense mapping for verification) - torch_output_lora = torch_naive_moe_with_lora( - hidden_states, - w13, - w2, - b13, - b2, - topk_weights, - topk_ids, - lora_info_delta, - token_lora_mapping, - ) - - torch_output_base = torch_naive_moe_with_lora( - hidden_states, - w13, - w2, - b13, - b2, - topk_weights, - topk_ids, - lora_info_baseline, - token_lora_mapping, - ) - - # The actual "Delta" (LoRA effect) for both - sglang_delta = output_with_lora - output_baseline - torch_delta = torch_output_lora - torch_output_base - - rtol, atol = 1e-1, 1e-2 - - torch.testing.assert_close(sglang_delta, torch_delta, rtol=rtol, atol=atol) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/test/registered/lora/test_lora_moe_tp_logprob_diff.py b/test/registered/lora/test_lora_moe_tp_logprob_diff.py index 36bad77b5c18..05a5c7b46d70 100644 --- a/test/registered/lora/test_lora_moe_tp_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_tp_logprob_diff.py @@ -54,7 +54,6 @@ def _run_sglang_moe_lora( model_type="generation", tp_size=tp_size, lora_paths=[MOE_LORA_PATH], - lora_backend="triton", max_loras_per_batch=1, trust_remote_code=True, disable_radix_cache=True, diff --git a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py index fe2adeb71a14..6926f1d89a58 100644 --- a/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_moe_vllm_sgl_logprob_diff.py @@ -291,7 +291,6 @@ def test_sglang_moe_parity_strict(self): torch_dtype=torch.bfloat16, model_type="generation", lora_paths=[MOE_LORA_PATH], - lora_backend="triton", max_loras_per_batch=1, tp_size=1, 
trust_remote_code=True, From c6c59b3f2090783db3cc310d9dd8146e222d3a4b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 24 Mar 2026 05:07:28 +0000 Subject: [PATCH 150/150] restore test_lora_hf_sgl_logprob_diff to main branch --- .../lora/test_lora_hf_sgl_logprob_diff.py | 118 +++++------------- 1 file changed, 34 insertions(+), 84 deletions(-) diff --git a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py index 5a9dc0eae901..c32c100a527b 100644 --- a/test/registered/lora/test_lora_hf_sgl_logprob_diff.py +++ b/test/registered/lora/test_lora_hf_sgl_logprob_diff.py @@ -12,6 +12,21 @@ # limitations under the License. # ============================================================================== +""" +Test to compare log probabilities between HuggingFace+LoRA and SGLang+LoRA. + +This test: +1. Runs SGLang with LoRA and collects log probabilities +2. Runs HuggingFace with LoRA and collects log probabilities +3. Compares the differences (max and mean) between the two implementations +4. 
Uses unittest framework for easy integration with test suites + +Usage: + python test_lora_hf_sgl_logprob_diff.py + or + python -m unittest test_lora_hf_sgl_logprob_diff +""" + import multiprocessing as mp import os import unittest @@ -20,24 +35,18 @@ import numpy as np import torch -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.lora_utils import ( - MOE_BASE_MODEL_PATH, - MOE_LORA_PATH, - MOE_LORA_TEST_PROMPTS, -) +from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.runners import HFRunner, SRTRunner -from sglang.test.test_utils import ( - DEFAULT_PORT_FOR_SRT_TEST_RUNNER, - CustomTestCase, - is_in_ci, -) +from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, CustomTestCase register_cuda_ci( est_time=150, - suite="stage-b-test-1-gpu-large", + suite="stage-b-test-1-gpu-small", +) +register_amd_ci( + est_time=250, + suite="stage-b-test-1-gpu-small-amd", ) - # Test configuration constants BASE_MODEL = "meta-llama/Llama-2-7b-hf" LORA_PATHS = ["yushengsu/sglang_lora_logprob_diff_without_tuning"] @@ -56,7 +65,6 @@ "What are the main components of a computer?", ] - # Formatting constants DIVIDER_WIDTH = 80 SECTION_CHAR = "=" @@ -98,16 +106,6 @@ def compare_logprobs_for_type( Returns: Dictionary containing comparison statistics """ - # It seems like HF is returning logprob for EOS, but SGLang is not. - min_len = min(sglang_logprobs.shape[0], hf_logprobs.shape[0]) - if sglang_logprobs.shape[0] != hf_logprobs.shape[0]: - print( - f"Warning: {logprob_type} logprob shape mismatch: SGLang {sglang_logprobs.shape}, " - f"HF {hf_logprobs.shape}. Truncating to length {min_len}." 
- ) - sglang_logprobs = sglang_logprobs[:min_len] - hf_logprobs = hf_logprobs[:min_len] - diff = torch.abs(sglang_logprobs - hf_logprobs) max_diff = torch.max(diff).item() mean_diff = torch.mean(diff).item() @@ -235,9 +233,8 @@ def run_sglang_with_lora( disable_cuda_graph=disable_cuda_graph, disable_radix_cache=True, port=port, - mem_fraction_static=0.8, + mem_fraction_static=0.88, lora_target_modules=lora_target_modules, - attention_backend="flashinfer", ) as srt_runner: srt_outputs = srt_runner.forward( prompts, @@ -456,8 +453,6 @@ def _run_comparison_test( disable_cuda_graph: bool = DISABLE_CUDA_GRAPH, lora_target_modules: Optional[List[str]] = LORA_TARGET_MODULES, tp_size: int = 1, - check_logprobs: bool = True, - output_match_threshold: Optional[float] = None, ): """ Run comparison test between SGLang and HuggingFace with LoRA. @@ -494,27 +489,17 @@ def _run_comparison_test( # Step 3: Compare log probabilities results, overall_stats = compare_logprobs(sglang_logprobs, hf_logprobs) - if check_logprobs: - # Assert that all prompts pass the threshold - for result in results: - self.assertTrue( - result["prefill_logprob_match"], - f"Prefill logprob mismatch for prompt {result['prompt_idx']} " - f"(max_diff={result['prefill_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", - ) - self.assertTrue( - result["decode_logprob_match"], - f"Decode logprob mismatch for prompt {result['prompt_idx']} " - f"(max_diff={result['decode_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", - ) - # MoE's expert layers make logprob comparisons useless as the base MoE layers' output significantly differs between sglang and hf - if output_match_threshold is not None: - outputs_match_count = sum(r["outputs_match"] for r in results) - match_rate = outputs_match_count / len(results) - self.assertGreaterEqual( - match_rate, - output_match_threshold, - f"Output string match rate {match_rate:.2%} is below threshold {output_match_threshold:.2%}", + # Assert that all prompts pass the 
threshold + for result in results: + self.assertTrue( + result["prefill_logprob_match"], + f"Prefill logprob mismatch for prompt {result['prompt_idx']} " + f"(max_diff={result['prefill_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", + ) + self.assertTrue( + result["decode_logprob_match"], + f"Decode logprob mismatch for prompt {result['prompt_idx']} " + f"(max_diff={result['decode_max_diff']:.6e}, threshold={LOGPROB_THRESHOLD:.0e})", ) print_section_header("Test completed successfully!") @@ -570,41 +555,6 @@ def test_lora_logprob_comparison_chunked(self): else: os.environ[key] = orig - def test_moe_lora_logprob_comparison_basic(self): - """ - Test comparing HF and SGLang MoE LoRA logprobs with basic prompts. - """ - model_path = MOE_BASE_MODEL_PATH - lora_paths = [MOE_LORA_PATH] - prompts = MOE_LORA_TEST_PROMPTS[:2] - - self._run_comparison_test( - model_path=model_path, - lora_paths=lora_paths, - prompts=prompts, - max_new_tokens=32, - check_logprobs=False, - output_match_threshold=0.9, - ) - - @unittest.skipIf(is_in_ci(), "Skipping full test in CI") - def test_moe_lora_logprob_comparison_full(self): - """ - Full test comparing HF and SGLang MoE LoRA logprobs with all default prompts. - """ - model_path = MOE_BASE_MODEL_PATH - lora_paths = [MOE_LORA_PATH] - prompts = MOE_LORA_TEST_PROMPTS - - self._run_comparison_test( - model_path=model_path, - lora_paths=lora_paths, - prompts=prompts, - max_new_tokens=32, - check_logprobs=False, - output_match_threshold=0.9, - ) - if __name__ == "__main__": try: