diff --git a/python/pyproject.toml b/python/pyproject.toml index b28d00dec465..9d0192ef4d89 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -10,107 +10,109 @@ readme = "README.md" requires-python = ">=3.10" license = { file = "LICENSE" } classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", -] -dependencies = [ - "IPython", - "aiohttp", - "anthropic>=0.20.0", - "blobfile==3.0.0", - "build", - "compressed-tensors", - "cuda-python", - "datasets", - "einops", - "fastapi", - "flashinfer_python==0.4.0", - "hf_transfer", - "huggingface_hub", - "interegular", - "llguidance>=0.7.11,<0.8.0", - "modelscope", - "msgspec", - "ninja", - "numpy", - "nvidia-cutlass-dsl==4.2.1", - "openai-harmony==0.0.4", - "openai==1.99.1", - "orjson", - "outlines==0.1.11", - "packaging", - "partial_json_parser", - "pillow", - "prometheus-client>=0.20.0", - "psutil", - "py-spy", - "pybase64", - "pydantic", - "nvidia-ml-py", - "python-multipart", - "pyzmq>=25.1.2", - "requests", - "scipy", - "sentencepiece", - "setproctitle", - "sgl-kernel==0.3.15", - "soundfile==0.13.1", - "tiktoken", - "timm==1.0.16", - "torch==2.8.0", - "torch_memory_saver==0.0.9rc2", - "torchao==0.9.0", - "torchaudio==2.8.0", - "torchvision", - "tqdm", - "transformers==4.57.1", - "uvicorn", - "uvloop", - "xgrammar==0.1.25", - "grpcio==1.75.1", # keep it align with compile_proto.py - "grpcio-tools==1.75.1", # keep it align with compile_proto.py - "grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", ] +dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle"] [project.optional-dependencies] -decord = ["decord2"] -test = [ - "accelerate", - "expecttest", - "gguf", - "jsonlines", - "matplotlib", - "pandas", - "peft", - "pytest", - "sentence_transformers", - "tabulate", +runtime_common = [ + "blobfile==3.0.0", + "build", + "compressed-tensors", + "datasets", + "einops", + "fastapi", + "hf_transfer", + "huggingface_hub", + "interegular", + "llguidance>=0.7.11,<0.8.0", + "modelscope", + "msgspec", + "ninja", + "openai==1.99.1", + "openai-harmony==0.0.4", + "orjson", + "outlines==0.1.11", + "packaging", + "partial_json_parser", + "pillow", + "prometheus-client>=0.20.0", + "psutil", + "pybase64", + "pydantic", + "pynvml", + "python-multipart", + "pyzmq>=25.1.2", + "scipy", + "sentencepiece", + "soundfile==0.13.1", + "timm==1.0.16", + "tiktoken", + "torchao==0.9.0", + "transformers==4.57.1", + "uvicorn", + "uvloop", + "xgrammar==0.1.25", ] + tracing = [ - "opentelemetry-api", - "opentelemetry-exporter-otlp", - "opentelemetry-exporter-otlp-proto-grpc", - "opentelemetry-sdk", + "opentelemetry-sdk", + "opentelemetry-api", + "opentelemetry-exporter-otlp", + "opentelemetry-exporter-otlp-proto-grpc", ] -all = ["sglang[test]", "sglang[decord]"] -cu130 = [ - "torch==2.9.0", - "torchaudio==2.9.0", - "torchvision==0.24.0", + +srt = [ + "sglang[runtime_common]", + "sgl-kernel==0.3.15", + "torch==2.8.0", + "torchaudio==2.8.0", + "torchvision", + "cuda-python", + "flashinfer_python==0.4.0", ] -cu130_all = [ - "sglang[test]", - "sglang[decord]", - "sglang[cu130]" + +# HIP (Heterogeneous-computing Interface for Portability) for AMD +# => base docker rocm/vllm-dev:20250114, not from public vllm whl +srt_hip = [ + "sglang[runtime_common]", + "torch", + "petit_kernel==0.0.2", + "wave-lang==3.7.0", ] +# https://docs.sglang.ai/platforms/ascend_npu.html +srt_npu = 
["sglang[runtime_common]"] + +# For Intel Gaudi(device : hpu) follow the installation guide +# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html +srt_hpu = ["sglang[runtime_common]"] -# The following will be deprecated in 2 weeks -dev = ["sglang[test]", "sglang[decord]"] -all_aarch64 = ["sglang[test]"] -blackwell = ["sglang[test]", "sglang[decord]"] -blackwell_aarch64 = ["sglang[test]"] +openai = ["openai==1.99.1", "tiktoken"] +anthropic = ["anthropic>=0.20.0"] +litellm = ["litellm>=1.0.0"] +torch_memory_saver = ["torch_memory_saver==0.0.9rc1"] +decord = ["decord"] +test = [ + "accelerate", + "expecttest", + "jsonlines", + "matplotlib", + "pandas", + "peft", + "sentence_transformers", + "pytest", + "tabulate", +] +all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"] +all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"] +all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"] +all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"] +dev = ["sglang[all]", "sglang[test]"] +dev_hip = ["sglang[all_hip]", "sglang[test]"] +dev_hpu = ["sglang[all_hpu]", "sglang[test]"] [project.urls] "Homepage" = "https://github.com/sgl-project/sglang" @@ -118,33 +120,31 @@ blackwell_aarch64 = ["sglang[test]"] [tool.setuptools.package-data] "sglang" = [ - "srt/layers/moe/fused_moe_triton/configs/*/*.json", - "srt/layers/quantization/configs/*.json", - "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp", - "srt/speculative/cpp_ngram/*.cpp", - "srt/speculative/cpp_ngram/*.h", + "srt/layers/moe/fused_moe_triton/configs/*/*.json", + "srt/layers/quantization/configs/*.json", + "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp", ] [tool.setuptools.packages.find] exclude = [ - "assets*", - "benchmark*", - "docs*", - "dist*", - "playground*", - "scripts*", - "tests*", + "assets*", + "benchmark*", + "docs*", + "dist*", + "playground*", + "scripts*", + "tests*", ] [tool.wheel] exclude = [ - "assets*", - "benchmark*", - "docs*", - "dist*", - "playground*", - "scripts*", - "tests*", + "assets*", + "benchmark*", + "docs*", + "dist*", + "playground*", + "scripts*", + "tests*", ] [tool.codespell] diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml deleted file mode 100755 index 9d0192ef4d89..000000000000 --- a/python/pyproject_other.toml +++ /dev/null @@ -1,152 +0,0 @@ -[build-system] -requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "sglang" -version = "0.5.3.post2" -description = "SGLang is a fast serving framework for large language models and vision language models." 
diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml deleted file mode 100755 index 9d0192ef4d89..000000000000 --- a/python/pyproject_other.toml +++ /dev/null @@ -1,152 +0,0 @@ -[build-system] -requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "sglang" -version = "0.5.3.post2" -description = "SGLang is a fast serving framework for large language models and vision language models." -readme = "README.md" -requires-python = ">=3.10" -license = { file = "LICENSE" } -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", -] -dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle"] - -[project.optional-dependencies] -runtime_common = [ - "blobfile==3.0.0", - "build", - "compressed-tensors", - "datasets", - "einops", - "fastapi", - "hf_transfer", - "huggingface_hub", - "interegular", - "llguidance>=0.7.11,<0.8.0", - "modelscope", - "msgspec", - "ninja", - "openai==1.99.1", - "openai-harmony==0.0.4", - "orjson", - "outlines==0.1.11", - "packaging", - "partial_json_parser", - "pillow", - "prometheus-client>=0.20.0", - "psutil", - "pybase64", - "pydantic", - "pynvml", - "python-multipart", - "pyzmq>=25.1.2", - "scipy", - "sentencepiece", - "soundfile==0.13.1", - "timm==1.0.16", - "tiktoken", - "torchao==0.9.0", - "transformers==4.57.1", - "uvicorn", - "uvloop", - "xgrammar==0.1.25", -] - -tracing = [ - "opentelemetry-sdk", - "opentelemetry-api", - "opentelemetry-exporter-otlp", - "opentelemetry-exporter-otlp-proto-grpc", -] - -srt = [ - "sglang[runtime_common]", - "sgl-kernel==0.3.15", - "torch==2.8.0", - "torchaudio==2.8.0", - "torchvision", - "cuda-python", - "flashinfer_python==0.4.0", -] - -# HIP (Heterogeneous-computing Interface for Portability) for AMD -# => base docker rocm/vllm-dev:20250114, not from public vllm whl -srt_hip = [ - "sglang[runtime_common]", - "torch", - "petit_kernel==0.0.2", - "wave-lang==3.7.0", -] - -# https://docs.sglang.ai/platforms/ascend_npu.html -srt_npu = ["sglang[runtime_common]"] - -# For Intel Gaudi(device : hpu) follow the installation guide -# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html -srt_hpu = ["sglang[runtime_common]"] - -openai = ["openai==1.99.1", "tiktoken"] -anthropic = ["anthropic>=0.20.0"] -litellm = ["litellm>=1.0.0"] -torch_memory_saver = ["torch_memory_saver==0.0.9rc1"] -decord = ["decord"] -test = [ - "accelerate", - "expecttest", - "jsonlines", - "matplotlib", - "pandas", - "peft", - "sentence_transformers", - "pytest", - "tabulate", -] -all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"] -all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"] -all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"] -all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"] - -dev = ["sglang[all]", "sglang[test]"] -dev_hip = ["sglang[all_hip]", "sglang[test]"] -dev_hpu = ["sglang[all_hpu]", "sglang[test]"] - -[project.urls] -"Homepage" = "https://github.com/sgl-project/sglang" -"Bug Tracker" = "https://github.com/sgl-project/sglang/issues" - -[tool.setuptools.package-data] -"sglang" = [ - "srt/layers/moe/fused_moe_triton/configs/*/*.json", - "srt/layers/quantization/configs/*.json", - "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp", -] - -[tool.setuptools.packages.find] -exclude = [ - "assets*", - "benchmark*", - "docs*", - "dist*", - "playground*", - "scripts*", - "tests*", -] - -[tool.wheel] -exclude = [ - "assets*", - "benchmark*", - "docs*", - "dist*", - "playground*", - "scripts*", - "tests*", -] - -[tool.codespell] -ignore-words-list = "ans, als, hel, boostrap, childs, te, vas, hsa, ment" -skip = "*.json,*.jsonl,*.patch,*.txt" diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index d52da3a6eb61..965a31647d23 100644 ---
a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -72,8 +72,6 @@ from sglang.srt.utils import ( configure_logger, get_bool_env_var, - is_cuda_alike, - is_xpu, kill_process_tree, require_mlp_sync, require_mlp_tp_gather, @@ -82,15 +80,6 @@ ) from sglang.srt.utils.hf_transformers_utils import get_tokenizer -profile_activities = [torch.profiler.ProfilerActivity.CPU] + [ - profiler_activity - for available, profiler_activity in [ - (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA), - (is_xpu(), torch.profiler.ProfilerActivity.XPU), - ] - if available -] - @dataclasses.dataclass class BenchArgs: @@ -435,7 +424,10 @@ def latency_test_run_once( profiler = None if profile: profiler = torch.profiler.profile( - activities=profile_activities, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], with_stack=True, record_shapes=profile_record_shapes, ) @@ -468,7 +460,10 @@ def latency_test_run_once( if profile and i == output_len / 2: profiler = None profiler = torch.profiler.profile( - activities=profile_activities, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], with_stack=True, record_shapes=profile_record_shapes, ) @@ -673,4 +668,4 @@ def main(server_args, bench_args): main(server_args, bench_args) finally: if server_args.tp_size != 1: - kill_process_tree(os.getpid(), include_parent=False) + kill_process_tree(os.getpid(), include_parent=False) \ No newline at end of file
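With the device-aware `profile_activities` helper removed above, the benchmark now always requests CPU and CUDA traces. A minimal sketch of the resulting pattern, assuming a CUDA-enabled torch build (an XPU run would need `ProfilerActivity.XPU` reintroduced along the lines of the deleted helper):

```python
import torch

# Hardcoded CPU+CUDA activities, mirroring the benchmark's new profiler setup.
profiler = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    with_stack=True,
    record_shapes=True,
)
profiler.start()
x = torch.randn(1024, 1024, device="cuda")
y = (x @ x).sum()  # stand-in for a forward/decode step
torch.cuda.synchronize()
profiler.stop()
profiler.export_chrome_trace("trace.json")
```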
diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 60b4e9e5f1fb..6e4a80500e5f 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -61,6 +61,8 @@ _is_gfx95_supported = is_gfx95_supported() if _use_aiter and _is_gfx95_supported: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048 @@ -217,8 +219,9 @@ def prepare_attn( hidden_states: torch.Tensor, residual: torch.Tensor, forward_batch: ForwardBatch, - qaunt_format: str = "", + quant_format: str = "", ): + if hidden_states.shape[0] == 0: residual = hidden_states else: @@ -233,10 +236,10 @@ def prepare_attn( ) ) else: + if residual is None: residual = hidden_states - - if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format): + if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format): hidden_states = fused_rms_mxfp4_quant( hidden_states, self.input_layernorm.weight, @@ -246,10 +249,29 @@ def prepare_attn( None, None, ) + + # --- FP8 dynamic-quant path (no residual) --- + elif _use_aiter and _is_gfx95_supported and ("fp8" in quant_format): + + hidden_states, _, _, _res = fused_rms_fp8_group_quant( + hidden_states, + self.input_layernorm.weight, + self.input_layernorm.variance_epsilon, + inp2=None, + inp2_weight=None, + inp2_epsilon=None, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=False, + ) + + # --- end FP8 dynamic-quant path (no residual) --- + + else: hidden_states = self.input_layernorm(hidden_states) else: - if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format): + if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format): hidden_states, residual = fused_rms_mxfp4_quant( hidden_states, self.input_layernorm.weight, @@ -259,6 +281,26 @@ def prepare_attn( None, residual, ) + # --- FP8 dynamic-quant path (residual) --- + elif _use_aiter and _is_gfx95_supported and ("fp8" in quant_format): + # RMSNorm + FP8 per-group quant. + # The returned hidden_states carries: + #   out_fp8 : FP8 activation → a8w8 GEMM + #   out_bs  : block scale → gemm_a8w8_blockscale.x_scale + hidden_states, _, _, residual = fused_rms_fp8_group_quant( + hidden_states, + self.input_layernorm.weight, + self.input_layernorm.variance_epsilon, + inp2=None, + inp2_weight=None, + inp2_epsilon=None, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=residual, + output_unquantized_inp1=False, + ) + + # --- end FP8 dynamic-quant path (residual) --- else: hidden_states, residual = self.input_layernorm( hidden_states, residual diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 2b34a2965506..aaa3e9bea8a6 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -243,6 +243,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None + # NOTE: with FP8 block-quant checkpoints, self.quant_method resolves to fp8.Fp8LinearMethod output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -418,6 +419,7 @@ def forward(self, input_): # Matrix multiply. assert self.quant_method is not None + # NOTE: ColumnParallelLinear likewise dispatches through self.quant_method.apply output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions.
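The rest of the patch adds tuned kernel tables for the MI300X "VF" (virtualized) device name: fused-MoE Triton configs and block-FP8 GEMM configs. Each top-level JSON key is the token count M the entry was benchmarked at. A rough sketch of how such a table is consumed (assumption: this mirrors the nearest-M lookup used by the fused-MoE config loader; `pick_tuned_config` and the path are illustrative, not the actual API):

```python
import json

def pick_tuned_config(path: str, num_tokens: int) -> dict:
    # Keys are the benchmarked batch sizes (M); pick the closest one.
    with open(path) as f:
        table = {int(m): cfg for m, cfg in json.load(f).items()}
    return table[min(table, key=lambda m: abs(m - num_tokens))]

# pick_tuned_config("E=256,N=256,...,block_shape=[128, 128].json", 200)
# would return the entry tuned for M=256.
```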
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X_VF.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X_VF.json new file mode 100644 index 000000000000..4d4b752fa5d6 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X_VF.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + 
"num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X_VF.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X_VF.json new file mode 100644 index 000000000000..a218fc40642c --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X_VF.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 
8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X_VF.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X_VF.json new file mode 100644 index 000000000000..3682cc548f35 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X_VF.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + 
"num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..d7f14d6656eb --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json @@ -0,0 +1,178 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16384": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32768": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "65536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "131072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 
1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X_VF.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X_VF.json new file mode 100644 index 000000000000..21742854c613 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X_VF.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..d9d2f5eac52f --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json @@ -0,0 +1,175 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + 
"waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16384": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32768": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "65536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "131072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..c098ef2dbb9a --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..6f5adbb93612 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + 
"1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..6f5adbb93612 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 
2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..4225c78eb72c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..4225c78eb72c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..5e6789d00e0a --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..49ac14d2a576 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff 
--git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..dcbb0efc53e4 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git 
a/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..dfe5c1e43d68 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 
128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..dfe5c1e43d68 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json 
b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..a87f5de1b183 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 
128].json new file mode 100644 index 000000000000..468f9e78da00 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index a1a25102d70c..727f5fbbf022 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -478,6 +478,16 @@ def apply( True, # is_vnni ) + if isinstance(x, tuple): + return self.w8a8_block_fp8_linear( + input=x[0], + weight=layer.weight, + 
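# x arrives as a (q_input, input_scale) pair when an upstream fused + # kernel has already quantized the activations: x[0] is the fp8 tensor + # and x[1] its scales, forwarded as input / input_scale here. +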
block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=x[1], + bias=bias, + ) + return self.w8a8_block_fp8_linear( input=x, weight=layer.weight, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index fc50c1f54639..07b5d6dc3053 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -263,19 +263,33 @@ def aiter_w8a8_block_fp8_linear( input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert input_scale is None + # input_scale may now be non-None: callers can pass pre-quantized input + input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] - q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8) + # a non-None input_scale means the input is already fp8-quantized + if input_scale is not None: + q_input = input_2d + x_scale = input_scale + + else: + q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8) + output = gemm_a8w8_blockscale( - q_input, weight, x_scale, weight_scale, dtype=input.dtype + q_input, + weight, + x_scale, + weight_scale, + dtype=torch.bfloat16 if input_scale is not None else input.dtype, ) if bias is not None: output += bias - return output.to(dtype=input_2d.dtype).view(*output_shape) + return output.to( + dtype=torch.bfloat16 if input_scale is not None else input_2d.dtype + ).view(*output_shape) def triton_w8a8_block_fp8_linear( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fb9cd4f6c9f8..b19bab85f65d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -14,6 +14,7 @@ # Adapted from: # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py + """Inference-only DeepseekV2 model.""" from __future__ import annotations @@ -152,6 +153,8 @@ _use_aiter_gfx95 = _use_aiter and _is_gfx95_supported if _use_aiter_gfx95: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights from sglang.srt.layers.quantization.rocm_mxfp4_utils import ( batched_gemm_afp4wfp4_pre_quant, @@ -1326,6 +1329,7 @@ def forward_prepare( return hidden_states, None, forward_batch, None attn_forward_method = self.dispatch_attn_forward_method(forward_batch) + # attn_forward_method selects the attention path (e.g. MHA or MLA) for this batch if attn_forward_method == AttnForwardMethod.MHA: inner_state = self.forward_normal_prepare( positions, hidden_states, forward_batch, zero_allocator ) @@ -1380,12 +1384,12 @@ def forward_core(self, intermediate_state): ) if inner_state is None: return hidden_states - + # dispatch to the attention core matching the method chosen in forward_prepare if attn_forward_method == AttnForwardMethod.MHA: return self.forward_normal_core(*inner_state) elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV: return self.forward_normal_chunked_kv_core(*inner_state) elif attn_forward_method == AttnForwardMethod.MLA: return self.forward_absorb_core(*inner_state) elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE: return self.forward_npu_sparse_core(*inner_state) @@ -1407,8 +1411,26 @@ def forward_normal_prepare( q, latent_cache =
self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) - q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + + if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.float8_e4m3fn: + + q, _, _, _ = fused_rms_fp8_group_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + None, + None, + None, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=False, + ) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: q = self.q_proj(hidden_states)[0].view( -1, self.num_local_heads, self.qk_head_dim @@ -1418,8 +1440,29 @@ forward_normal_prepare( _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a) - kv = self.kv_b_proj(kv_a)[0] + + if _use_aiter_gfx95 and self.kv_b_proj.weight.dtype == torch.float8_e4m3fn: + + kv_a_quantized, kv_a, _, _ = fused_rms_fp8_group_quant( + kv_a, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + None, + None, + None, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=True,  # also return the unquantized kv_a + ) + kv = self.kv_b_proj( + kv_a_quantized, + )[0] + + else: + kv_a = self.kv_a_layernorm(kv_a) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] v = kv[..., self.qk_nope_head_dim :] @@ -1510,6 +1553,7 @@ def forward_absorb_prepare( k_nope = self.kv_a_layernorm(k_nope) current_stream.wait_stream(self.alt_stream) else: + if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8: q, k_nope = fused_rms_mxfp4_quant( q, @@ -1519,9 +1563,29 @@ self.kv_a_layernorm.weight, self.kv_a_layernorm.variance_epsilon, ) + else: - q = self.q_a_layernorm(q) - k_nope = self.kv_a_layernorm(k_nope) + if ( + _use_aiter_gfx95 + and self.q_b_proj.weight.dtype == torch.float8_e4m3fn + ): + + q, _, k_nope, _ = fused_rms_fp8_group_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + k_nope, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=False, + ) + + else: + q = self.q_a_layernorm(q) + k_nope = self.kv_a_layernorm(k_nope) # q_lora needed by indexer if self.use_nsa: @@ -1557,6 +1621,7 @@ q_nope_out = q_nope_out[:, :expected_m, :] elif _is_hip: # TODO(haishaw): add bmm_fp8 to ROCm + if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8: x = q_nope.transpose(0, 1) q_nope_out = torch.empty( @@ -1839,8 +1904,27 @@ def forward_npu_sparse_prepare( self.kv_a_layernorm.variance_epsilon, ) else: - q = self.q_a_layernorm(q) - k_nope = self.kv_a_layernorm(k_nope) + if ( + _use_aiter_gfx95 + and self.q_b_proj.weight.dtype == torch.float8_e4m3fn + ): + + q, _, k_nope, _ = fused_rms_fp8_group_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + k_nope, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=False, +
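# One fused aiter kernel RMS-norms q and k_nope and quantizes both to + # per-128-group fp8 in a single launch, replacing the separate + # q_a_layernorm / kv_a_layernorm calls of the fallback path below. +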
) + + else: + q = self.q_a_layernorm(q) + k_nope = self.kv_a_layernorm(k_nope) q_lora = q.clone() # required for topk_indices k_nope = k_nope.unsqueeze(1) @@ -1994,6 +2078,7 @@ def forward_absorb_fused_mla_rope_prepare( q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) + q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: @@ -2469,6 +2554,16 @@ def forward( and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8 else "" ) + quant_format = ( + "fp8" + if _is_gfx95_supported + and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None + and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None) + is not None + and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype + == torch.float8_e4m3fn + else quant_format  # keep "mxfp4" from above instead of clobbering it with "" + ) hidden_states, residual = self.layer_communicator.prepare_attn( hidden_states, diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce.hip b/sgl-kernel/csrc/allreduce/quick_all_reduce.hip new file mode 100644 index 000000000000..de07cadce10b --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce.hip @@ -0,0 +1,112 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#include +#include +#include + +#ifdef USE_ROCM + +#include "quick_all_reduce_hip.h" + +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size) { + if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); + fptr->init(world_size, rank, qr_max_size); + return (quickreduce::fptr_t)fptr; +} + +void qr_destroy(quickreduce::fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); + fa->destroy(); + delete fa; + } +} + +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options); + std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(quickreduce::fptr_t _fa, const std::vector<torch::Tensor>& handles) { + auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); + std::vector<hipIpcMemHandle_t> ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current device.
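+    // Each peer's handle arrived as a CPU uint8 tensor (see qr_get_handle); + // rebuild the raw hipIpcMemHandle_t from its bytes before opening it.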
+ hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce( + quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { + auto fa = reinterpret_cast(_fa); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + if (cast_bf2half) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } else { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } + } else { + throw std::runtime_error("quick allreduce only supports float16 and bfloat16"); + } +} + +int64_t qr_max_size() { + // The default is 2GB (2,147,483,648 bytes) + return static_cast(std::numeric_limits::max()) + 1; +} + +#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; + +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) + +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) + +#endif // USE_ROCM diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce_hip.h b/sgl-kernel/csrc/allreduce/quick_all_reduce_hip.h new file mode 100644 index 000000000000..f81e330f5c19 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce_hip.h @@ -0,0 +1,234 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once + +#include + +#include + +#include "quick_all_reduce.cuh" + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d. 
%s\n", err_, __FILE__, __LINE__, hipGetErrorString(err_)); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +namespace quickreduce { +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +template +__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot( + T const* A, + T* B, + uint32_t N, + uint32_t num_blocks, + int rank, + uint8_t** dbuffer_list, + uint32_t data_offset, + uint32_t flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color); + block += grid; + flag_color++; + } +} + +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } + +enum QuickReduceQuantLevel { + F16 = 0, + INT8 = 1, + INT6 = 2, + INT4 = 3, +}; + +struct DeviceComms { + // Max problem size is 2GB (in bytes) or half of uint32_t max value. + int64_t kMaxProblemSize = static_cast(std::numeric_limits::max()) + 1; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + uint32_t flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + uint32_t data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { + destroy(); + } + + void init(int world_size, int rank, std::optional max_problem_size = std::nullopt) { + destroy(); + this->world_size = world_size; + this->rank = rank; + if (max_problem_size.has_value() && max_problem_size.value() > 0) { + this->kMaxProblemSize = max_problem_size.value(); + } + // Allocate buffer size for worst case: F16 2-stage buffer. + uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t); + static int64_t data_buffer_size = 2 * this->kMaxProblemSize; + int64_t total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. 
+ all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; + } + int get_world_size() { + return world_size; + } + int get_rank() { + return rank; + } + bool status() { + return initialized; + } + hipIpcMemHandle_t const get_handle() { + return buffer_ipc_handle; + } + + void destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } + } + + void open_ipc_handles(std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK( + hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); + } + + template + void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size)); + } + + // Configuration. + uint32_t msg_size = N * sizeof(T); + uint32_t num_blocks = divceil(msg_size, kTileSize); + uint32_t grid = min(kMaxNumBlocks, num_blocks); + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; + } + HIP_CHECK(hipGetLastError()); + // Rotate the flag color. + flag_color += divceil(N, grid); + } +}; + +} // namespace quickreduce diff --git a/sgl-kernel/csrc/elementwise/activation.hip b/sgl-kernel/csrc/elementwise/activation.hip new file mode 100644 index 000000000000..25cc9774e630 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/activation.hip @@ -0,0 +1,172 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#ifndef USE_ROCM + +#include + +#include "utils_hip.h" + +#else +#include "hip/hip_act_and_mul_hip.cuh" +#endif + +// Adapted from flashinfer activation +// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44 + +namespace detail { + +template +__device__ __forceinline__ float to_f32(const T& x) { +#if USE_ROCM + return castToFloat(x); +#else + return static_cast(x); +#endif +} + +template +__device__ __forceinline__ T from_f32(float f32) { +#if USE_ROCM + return castFromFloat(f32); +#else + return static_cast(f32); +#endif +} + +} // namespace detail + +template +__device__ __forceinline__ T silu(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val))); +} + +template +__device__ __forceinline__ T gelu(const T& x) { + constexpr float kAlpha = M_SQRT1_2; + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha)))); +} + +// gelu_quick(x) = x * torch.sigmoid(1.702 * x) +template +__device__ __forceinline__ T gelu_quick_act(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val * 1.702f))); +} + +template +__device__ __forceinline__ T gelu_tanh(const T& x) { + constexpr float kAlpha = 0.044715f; + constexpr float kBeta = 0.7978845608028654f; + float f32_val = detail::to_f32(x); + const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val)))); + return detail::from_f32(f32_val * cdf); +} + +void silu_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(::min(d / vec_size, 1024U)); +#if USE_ROCM + hipLaunchKernelGGL(( sgl_hip::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + return true; + }); +} + +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(::min(d / vec_size, 1024U)); +#if USE_ROCM + hipLaunchKernelGGL(( sgl_hip::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + return true; + }); +} + +void gelu_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 
grid(num_tokens); + + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(::min(d / vec_size, 1024U)); +#if USE_ROCM + hipLaunchKernelGGL(( sgl_hip::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + + return true; + }); +} + +#if USE_ROCM +void gelu_quick(at::Tensor& out, const at::Tensor& input) { + int d = input.size(-1); + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(::min(d / vec_size, 1024U)); + hipLaunchKernelGGL(( sgl_hip::activation::act_only_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + + return true; + }); +} +#endif diff --git a/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_hip.hip b/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_hip.hip new file mode 100644 index 000000000000..477944a1c178 --- /dev/null +++ b/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_hip.hip @@ -0,0 +1,272 @@ +// !!! This is a file automatically generated by hipify!!! +// Adapted from +// https://github.com/mlc-ai/xgrammar/blob/v0.1.18/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu + +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// clang-format off +#include +#include +#include +#include +#include + + +#if !defined(USE_ROCM) && (!defined(TORCH_HIP_VERSION) || TORCH_HIP_VERSION < 12040) +void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional indices = at::nullopt) { + TORCH_CHECK(false, "CUDA version must be >= 12.4 for ApplyTokenBitmaskInplace"); +} +#else + +#ifndef CUDART_INF_FP16 +#ifndef USE_ROCM +#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U) +#endif +#endif + +#ifndef CUDART_INF_BF16 +#ifndef USE_ROCM +#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) +#endif +#endif + +constexpr int32_t BITS_PER_BLOCK = 32; +constexpr int32_t THREADS_PER_THREAD_BLOCK = 256; + +template +__device__ T NegativeInfinity() { + return -INFINITY; +} + +template <> +__device__ __half NegativeInfinity<__half>() { +#ifdef USE_ROCM + return __float2half(-INFINITY); +#else + return -CUDART_INF_FP16; +#endif +} + +template <> +__device__ __hip_bfloat16 NegativeInfinity<__hip_bfloat16>() { +#ifdef USE_ROCM + return __hip_bfloat16(-INFINITY); +#else + return -CUDART_INF_BF16; +#endif +} + +template +__device__ PackedT PackedNegativeInfinity() { + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); + T packed[kAlignment]; +#pragma unroll + for (int i = 0; i < kAlignment; i++) { + packed[i] = NegativeInfinity(); + } + return *reinterpret_cast(packed); +} + +template +__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel( + T* __restrict__ logits, + const int32_t* __restrict__ bitmask, + const int32_t* __restrict__ indices, + int32_t vocab_size, + int32_t logits_stride, + int32_t bitmask_stride) { + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); + constexpr uint32_t kPackedMask = (1 << kAlignment) - 1; + + const int batch_idx = (indices == nullptr) ? 
blockIdx.y : indices[blockIdx.y]; + + const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread; + T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset; + const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK; + const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment); + T logits_reg[kAlignment]; + +#pragma unroll + for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread; + offset += THREADS_PER_THREAD_BLOCK * kAlignment) { + if (block_offset + offset >= vocab_size) { + break; + } + + const uint32_t bitmask_val = + (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask; + + if (bitmask_val == 0) { + continue; + } + + if (bitmask_val == kPackedMask) { + *reinterpret_cast(logits_gmem_ptr + offset) = PackedNegativeInfinity(); + continue; + } + + *reinterpret_cast(logits_reg) = *reinterpret_cast(logits_gmem_ptr + offset); +#pragma unroll + for (int i = 0; i < kAlignment; i++) { + if (((bitmask_val >> i) & 1)) { + logits_reg[i] = NegativeInfinity(); + } + } + *reinterpret_cast(logits_gmem_ptr + offset) = *reinterpret_cast(logits_reg); + } +} + +template ::value>> +constexpr auto CeilDiv(T numerator, T denominator) { + return (numerator + denominator - 1) / denominator; +} + +template +void ApplyTokenBitmaskInplaceDispatchToBitsPerThread( + T* __restrict__ logits, + const int32_t* __restrict__ bitmask, + const int32_t* __restrict__ indices, + int32_t vocab_size, + int32_t logits_stride, + int32_t bitmask_stride, + int32_t num_rows) { + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); + const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows); + const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row); + + const dim3 block(THREADS_PER_THREAD_BLOCK); + hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + + if (num_bits_per_thread <= 4 && kAlignment <= 4) { + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows); + hipLaunchKernelGGL(( LogitsBitmaskKernel) + , dim3(grid), dim3(block), 0, stream, logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); + } else if (num_bits_per_thread <= 8 && kAlignment <= 8) { + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows); + hipLaunchKernelGGL(( LogitsBitmaskKernel) + , dim3(grid), dim3(block), 0, stream, logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); + } else if (num_bits_per_thread <= 16 && kAlignment <= 16) { + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows); + hipLaunchKernelGGL(( LogitsBitmaskKernel) + , dim3(grid), dim3(block), 0, stream, logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); + } else { + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows); + hipLaunchKernelGGL(( LogitsBitmaskKernel) + , dim3(grid), dim3(block), 0, stream, logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); + } +} + +template +void ApplyTokenBitmaskInplaceDispatchToPackedT( + T* __restrict__ logits, + const int32_t* __restrict__ bitmask, + const int32_t* __restrict__ indices, + int32_t vocab_size, + int32_t logits_stride, + int32_t bitmask_stride, + int32_t num_rows) { + if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) { + ApplyTokenBitmaskInplaceDispatchToBitsPerThread( + logits, bitmask, 
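// 16-byte fast path: the row stride divides evenly into float4 vectors, so each thread masks sizeof(float4)/sizeof(T) logits per packed access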
indices, vocab_size, logits_stride, bitmask_stride, num_rows); + } else { + ApplyTokenBitmaskInplaceDispatchToBitsPerThread( + logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows); + } +} + +void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional indices = at::nullopt) { + TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor."); + TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous."); + TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor."); + std::pair logits_shape = + logits.dim() == 2 ? std::make_pair(static_cast(logits.size(0)), static_cast(logits.size(1))) + : std::make_pair(1, static_cast(logits.size(0))); + + TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor."); + TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous."); + TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor."); + std::pair bitmask_shape = + bitmask.dim() == 2 ? std::make_pair(static_cast(bitmask.size(0)), static_cast(bitmask.size(1))) + : std::make_pair(1, static_cast(bitmask.size(0))); + + TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32."); + + TORCH_CHECK( + (logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second, + "The provided logits's vocab size should be no less than the bitmask's vocab size " + "(converted from bitmask size). But got vocab size ", + logits_shape.second, + " vs bitmask size ", + bitmask_shape.second); + + int vocab_size = ::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK); + + int32_t num_rows = logits_shape.first; + int32_t* indices_ptr = nullptr; + if (indices) { + TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor."); + TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous."); + TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor."); + TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32."); + num_rows = indices->size(0); + indices_ptr = indices->data_ptr(); + } else { + TORCH_CHECK(logits_shape.first == bitmask_shape.first, "logits and bitmask must have the same batch size."); + } + + switch (logits.scalar_type()) { + case torch::kFloat32: { + ApplyTokenBitmaskInplaceDispatchToPackedT( + logits.data_ptr(), + bitmask.data_ptr(), + indices_ptr, + vocab_size, + logits_shape.second, + bitmask_shape.second, + num_rows); + break; + } + case torch::kFloat16: { + ApplyTokenBitmaskInplaceDispatchToPackedT( + reinterpret_cast<__half*>(logits.data_ptr()), + bitmask.data_ptr(), + indices_ptr, + vocab_size, + logits_shape.second, + bitmask_shape.second, + num_rows); + break; + } + case torch::kBFloat16: { + ApplyTokenBitmaskInplaceDispatchToPackedT( + reinterpret_cast<__hip_bfloat16*>(logits.data_ptr()), + bitmask.data_ptr(), + indices_ptr, + vocab_size, + logits_shape.second, + bitmask_shape.second, + num_rows); + break; + } + default: + TORCH_CHECK(false, "logits dtype must be float, half or bfloat16."); + break; + } +} +#endif +// clang-format on diff --git a/sgl-kernel/csrc/kvcacheio/transfer.hip b/sgl-kernel/csrc/kvcacheio/transfer.hip new file mode 100644 index 000000000000..1ea658b36973 --- /dev/null +++ b/sgl-kernel/csrc/kvcacheio/transfer.hip @@ -0,0 +1,575 @@ +// !!! This is a file automatically generated by hipify!!! 
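+// Warp-cooperative KV-cache page copy: each warp streams whole pages with +// nontemporal 64-bit loads/stores, and the get_global_offset_* helpers map +// (layer, page) to addresses for layer-first, page-first, and table layouts.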
+#include "hip/hip_runtime.h" +#include +#include +#include + +#include + +#ifndef USE_ROCM +#define WARP_SIZE 32 +#include "pytorch_extension_utils.h" +#else +#include "pytorch_extension_utils_rocm.h" +#include "utils_hip.h" // WARP_SIZE +#endif + +__device__ __forceinline__ void +transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) { + const uint64_t* __restrict__ src = static_cast(src_addr); + uint64_t* __restrict__ dst = static_cast(dst_addr); + const int total_chunks = item_size_bytes / sizeof(uint64_t); + +#pragma unroll + for (int j = lane_id; j < total_chunks; j += WARP_SIZE) { +#ifndef USE_ROCM + uint64_t tmp; + asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src + j) : "memory"); + asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp) : "memory"); + +#else + uint64_t tmp = __builtin_nontemporal_load(src + j); + __builtin_nontemporal_store(tmp, dst + j); +#endif + } +} + +template +__device__ __forceinline__ T* get_global_offset_lf( + T* base, + const uintptr_t* __restrict__ /*unused*/, + int64_t layer_id, + int64_t layer_dim, + int64_t page_id, + int64_t item_size_bytes) { + // layer first + return base + layer_id * layer_dim + page_id * item_size_bytes; +} + +template +__device__ __forceinline__ T* get_global_offset_pf( + T* base, + const uintptr_t* __restrict__ /*unused*/, + int64_t layer_id, + int64_t page_dim, + int64_t page_id, + int64_t item_size_bytes) { + // page first + return base + page_id * page_dim + layer_id * item_size_bytes; +} + +// get offset from layer base table when layers are not contiguous +template +__device__ __forceinline__ T* get_global_offset_lf_tbl( + T* /*unused*/, + const uintptr_t* __restrict__ layer_base_tbl, + int64_t layer_id, + int64_t /*unused*/, + int64_t page_id, + int64_t item_size_bytes) { + return reinterpret_cast(layer_base_tbl[layer_id]) + page_id * item_size_bytes; +} + +template +__global__ void transfer_kernel_impl( + const void* __restrict__ src_k, + void* __restrict__ dst_k, + const void* __restrict__ src_v, + void* __restrict__ dst_v, + const int64_t* __restrict__ src_indices, + const int64_t* __restrict__ dst_indices, + int64_t start_layer_id, + int64_t num_layers_to_process, + int64_t num_items, + int64_t items_per_warp, + int64_t item_size_bytes, + int64_t src_layout_dim, + int64_t dst_layout_dim, + const uintptr_t* __restrict__ src_k_layer_tbl, + const uintptr_t* __restrict__ dst_k_layer_tbl, + const uintptr_t* __restrict__ src_v_layer_tbl, + const uintptr_t* __restrict__ dst_v_layer_tbl) { + int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + int32_t lane_id = tid % WARP_SIZE; + int32_t warp_id = tid / WARP_SIZE; + + for (int i = 0; i < items_per_warp; ++i) { + int64_t item_id = warp_id * items_per_warp + i; + if (item_id >= num_items) { + break; + } + const int64_t src_page_id = src_indices[item_id]; + const int64_t dst_page_id = dst_indices[item_id]; + + // Loop over layers if necessary + for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) { + const char* src_ptr = SrcOffsetFn( + static_cast(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes); + char* dst_ptr = DstOffsetFn( + static_cast(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes); + transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes); + + if constexpr (!IsMLA) { + const char* src_v_ptr = SrcOffsetFn( + static_cast(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, 
item_size_bytes); + char* dst_v_ptr = DstOffsetFn( + static_cast(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes); + transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes); + } + } + } +} + +template +void transfer_kv_launcher( + const at::Tensor& src_k, + at::Tensor& dst_k, + const at::Tensor& src_v, + at::Tensor& dst_v, + const at::Tensor& src_indices, + const at::Tensor& dst_indices, + int64_t start_layer_id, + int64_t num_layers_to_process, + int64_t item_size, + int64_t src_layout_dim, + int64_t dst_layout_dim, + const at::Tensor& src_k_layers, + const at::Tensor& dst_k_layers, + const at::Tensor& src_v_layers, + const at::Tensor& dst_v_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor"); + TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor"); + TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long"); + TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long"); + TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length"); + TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8"); + + auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; }; + const int64_t num_items = src_indices.numel(); + const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block); + const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block); + dim3 grid_dim(num_blocks, 1, 1); + const int32_t threads_per_block = num_warps_per_block * WARP_SIZE; + + const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr; + void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr; + const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr(); + void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr(); + const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr() : nullptr; + const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr() : nullptr; + const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr(); + const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? 
nullptr : dst_v_layers.data_ptr(); + + hipStream_t torch_current_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + hipLaunchKernelGGL(( transfer_kernel_impl), dim3(grid_dim), dim3(threads_per_block), 0, torch_current_stream, + src_k_ptr, + dst_k_ptr, + src_v_ptr, + dst_v_ptr, + src_indices.data_ptr(), + dst_indices.data_ptr(), + start_layer_id, + num_layers_to_process, + num_items, + items_per_warp, + item_size, + src_layout_dim, + dst_layout_dim, + src_k_tbl_ptr, + dst_k_tbl_ptr, + src_v_tbl_ptr, + dst_v_tbl_ptr); + C10_HIP_KERNEL_LAUNCH_CHECK(); +} + +void transfer_kv_per_layer( + const at::Tensor src_k, + at::Tensor dst_k, + const at::Tensor src_v, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, false>( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + 0, + 1, + item_size, + 0, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_per_layer_pf_lf( + const at::Tensor src_k, + at::Tensor dst_k, + const at::Tensor src_v, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t layer_id, + int64_t item_size, + int64_t src_layout_dim, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, false>( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + layer_id, + 1, + item_size, + src_layout_dim, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer( + const at::Tensor src_k_layers, + const at::Tensor dst_k_layers, + const at::Tensor src_v_layers, + const at::Tensor dst_v_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf_tbl, false>( + empty, + empty, + empty, + empty, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + 0, + src_k_layers, + dst_k_layers, + src_v_layers, + dst_v_layers, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_lf_pf( + const at::Tensor src_k_layers, + at::Tensor dst_k, + const at::Tensor src_v_layers, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_pf, false>( + empty, + dst_k, + empty, + dst_v, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + dst_layout_dim, + src_k_layers, + empty, + src_v_layers, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_per_layer_mla( + const at::Tensor src, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, true>( + src, + dst, + empty, + empty, + src_indices, + dst_indices, + 0, + 1, + item_size, + 0, + 0, + empty, + empty, + empty, + 
empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_per_layer_mla_pf_lf( + const at::Tensor src, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t layer_id, + int64_t item_size, + int64_t src_layout_dim, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, true>( + src, + dst, + empty, + empty, + src_indices, + dst_indices, + layer_id, + 1, + item_size, + src_layout_dim, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_mla( + const at::Tensor src_layers, + const at::Tensor dst_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf_tbl, true>( + empty, + empty, + empty, + empty, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + 0, + src_layers, + dst_layers, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_mla_lf_pf( + const at::Tensor src_layers, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_pf, true>( + empty, + dst, + empty, + empty, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + dst_layout_dim, + src_layers, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +inline void transfer_page_direct( + const at::Tensor src_buffer, + at::Tensor dst_buffer, + int64_t src_page_index, + int64_t dst_page_index, + int64_t page_size) { + dst_buffer.slice(0, dst_page_index, dst_page_index + page_size) + .copy_( + src_buffer.slice(0, src_page_index, src_page_index + page_size), + /* non_blocking= */ true); +} + +void transfer_kv_direct( + const std::vector& src_layers, + std::vector dst_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t page_size) { + TORCH_CHECK( + src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers"); + TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length"); + TORCH_CHECK(page_size > 0, "Page size must be positive"); + TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size"); + + auto src_indices_cpu = src_indices.cpu(); + auto dst_indices_cpu = dst_indices.cpu(); + + const auto num_indices = src_indices_cpu.numel(); + const int64_t num_layers = src_layers.size(); + int64_t* src_indices_ptr = src_indices_cpu.data_ptr(); + int64_t* dst_indices_ptr = dst_indices_cpu.data_ptr(); + + int64_t start_index = 0; + int64_t end_index = 0; + + for (int64_t i = 0; i < num_indices; ++i) { + if (i < num_indices - 1) { + auto src_diff = src_indices_ptr[i + 1] - src_indices_ptr[i]; + auto dst_diff = dst_indices_ptr[i + 1] - dst_indices_ptr[i]; + + if (src_diff == 1 && dst_diff == 1) { + continue; + } + end_index = i + 1; + } else { // last batch + end_index = num_indices; + } + auto src_index = 
src_indices_ptr[start_index]; + auto dst_index = dst_indices_ptr[start_index]; + auto num_tokens = end_index - start_index; + + for (int64_t j = 0; j < num_layers; ++j) { + transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, num_tokens); + } + start_index = end_index; + } +} + +template +inline void transfer_kv_page_first_direct_impl( + const std::vector& src_ptrs, + std::vector dst_ptrs, + const at::Tensor& src_indices, + const at::Tensor& dst_indices, + int64_t start_layer_id, + int64_t page_size) { + TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length"); + TORCH_CHECK(page_size > 0, "Page size must be positive"); + TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size"); + + auto src_indices_cpu = src_indices.cpu(); + auto dst_indices_cpu = dst_indices.cpu(); + const int64_t num_pages = src_indices_cpu.size(0) / page_size; + + if constexpr (IsLf2Pf) { + const bool is_mla = dst_ptrs.size() == 1; + const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2; + + for (const auto i : c10::irange(num_pages)) { + auto s_index = src_indices_cpu[i * page_size].item(); + auto d_index = dst_indices_cpu[i * page_size].item() / page_size; + for (int64_t j = 0; j < num_layers; ++j) { + transfer_page_direct( + src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size); + if (!is_mla) { + transfer_page_direct( + src_ptrs[j + num_layers], + dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j), + s_index, + 0, + page_size); + } + } + } + } else { + const bool is_mla = src_ptrs.size() == 1; + const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2; + + for (const auto i : c10::irange(num_pages)) { + auto s_index = src_indices_cpu[i * page_size].item() / page_size; + auto d_index = dst_indices_cpu[i * page_size].item(); + for (int64_t j = 0; j < num_layers; ++j) { + transfer_page_direct( + src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size); + if (!is_mla) { + transfer_page_direct( + src_ptrs[1].select(0, s_index).select(0, start_layer_id + j), + dst_ptrs[j + num_layers], + 0, + d_index, + page_size); + } + } + } + } +} + +void transfer_kv_per_layer_direct_pf_lf( + const std::vector& src_ptrs, + std::vector dst_ptrs, + const at::Tensor& src_indices, + const at::Tensor& dst_indices, + int64_t layer_id, + int64_t page_size) { + transfer_kv_page_first_direct_impl(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size); +} + +void transfer_kv_all_layer_direct_lf_pf( + const std::vector& src_ptrs, + std::vector dst_ptrs, + const at::Tensor& src_indices, + const at::Tensor& dst_indices, + int64_t page_size) { + transfer_kv_page_first_direct_impl(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size); +} diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.hip b/sgl-kernel/csrc/moe/moe_align_kernel.hip new file mode 100644 index 000000000000..a3e231dc986e --- /dev/null +++ b/sgl-kernel/csrc/moe/moe_align_kernel.hip @@ -0,0 +1,366 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include + +#include "utils_hip.h" + +#define VEC_SIZE 4 +using Vec = int4; + +template +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + +#ifdef __CUDA_ARCH__ +__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) { + int original = v; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int n = __shfl_up_sync(mask, v, offset); + if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n; + } + return v - original; +} +#endif + +template +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + bool pad_sorted_token_ids, + const int32_t scan_size) { + extern __shared__ int32_t smem[]; + int32_t* shared_counts = smem; // [num_experts] + int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] + int32_t* scan_buf = prefix + num_experts + 1; // [scan_size] + __shared__ int32_t s_total_tokens_post_pad; + + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + if (tid < num_experts) { + shared_counts[tid] = 0; + } + + __syncthreads(); + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i] + 1; + atomicAdd(&shared_counts[expert_id], 1); + } + + __syncthreads(); + + int32_t padded_count = 0; + if (tid < num_experts) { + int32_t count = shared_counts[tid]; + padded_count = (count + block_size - 1) / block_size * block_size; + scan_buf[tid] = padded_count; + } + +#ifndef __CUDA_ARCH__ // HIP + + if (tid >= num_experts && tid < scan_size) { + scan_buf[tid] = 0; + } + + __syncthreads(); + + // Blelloch scan + int offset = 1; +#pragma unroll + for (int d = scan_size >> 1; d > 0; d >>= 1) { + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + scan_buf[bi] += scan_buf[ai]; + } + offset <<= 1; + __syncthreads(); + } + + // down-sweep + if (tid == 0) { + prefix[num_experts] = scan_buf[scan_size - 1]; + scan_buf[scan_size - 1] = 0; + } + __syncthreads(); + +#pragma unroll + for (int d = 1; d < scan_size; d <<= 1) { + offset >>= 1; + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + if (bi < scan_size) { + int temp = scan_buf[ai]; + scan_buf[ai] = scan_buf[bi]; + scan_buf[bi] += temp; + } + } + __syncthreads(); + } + + if (tid < num_experts) { + prefix[tid] = scan_buf[tid]; + } + + if (tid == 0) { + 
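+    // At this point the Blelloch up/down-sweep is complete: prefix[e] holds
+    // the exclusive prefix sum of the block-padded per-expert counts and
+    // prefix[num_experts] their total. A small worked example with
+    // illustrative counts (block_size = 4, four experts):
+    //   padded counts [4, 8, 4, 0] -> up-sweep  [4, 12, 4, 16]
+    //   zero the root, down-sweep  -> prefix    [0, 4, 12, 16], total = 16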
+    s_total_tokens_post_pad = prefix[num_experts];
+    *total_tokens_post_pad = s_total_tokens_post_pad;
+  }
+  __syncthreads();
+
+#else // CUDA
+
+  // Intra-warp prefix sum
+  int32_t* warp_sums = scan_buf + scan_size;  // [<= 32]
+  const int warp_id = tid / WARP_SIZE;
+  const int lane_id = tid & (WARP_SIZE - 1);
+  const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE;
+  const int warp_sum = warp_exclusive_scan(padded_count) + padded_count;
+  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum;
+  __syncthreads();
+
+  // Warp 0 accumulates the per-warp sums into a block-level prefix sum
+  if (tid < WARP_SIZE) {
+    int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0;
+    int incl = warp_exclusive_scan(val) + val;
+    warp_sums[tid] = incl;
+  }
+  __syncthreads();
+
+  // Thread 0 publishes the whole block's padded total
+  if (tid == 0) {
+    prefix[num_experts] = warp_sums[num_warps_for_scan - 1];
+    s_total_tokens_post_pad = prefix[num_experts];
+    *total_tokens_post_pad = s_total_tokens_post_pad;
+  }
+  __syncthreads();
+
+  // Zero-fill the extended area of scan_buf (tid >= num_experts)
+  if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0;
+  __syncthreads();
+
+  // Perform a two-level exclusive prefix sum over scan_buf
+  int v = (tid < scan_size) ? scan_buf[tid] : 0;
+  int pre = warp_exclusive_scan(v);
+  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v;
+  __syncthreads();
+
+  if (warp_id == 0) {
+    int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0;
+    warp_sums[lane_id] = warp_exclusive_scan(val);
+  }
+  __syncthreads();
+
+  int offset = warp_sums[warp_id];
+  if (tid < scan_size) scan_buf[tid] = pre + offset;
+  __syncthreads();
+
+  // Write prefix[0..num_experts - 1] and cumsum
+  if (tid < num_experts) prefix[tid] = scan_buf[tid];
+#endif
+
+  if (tid <= num_experts) {
+    cumsum[tid] = prefix[tid];
+  }
+  // Fill expert_ids: binary-search the padded prefix for the expert that owns each block
+  const int32_t num_blocks = s_total_tokens_post_pad / block_size;
+  for (int32_t i = tid; i < num_blocks; i += stride) {
+    int32_t block_start = i * block_size;
+    int left = 0, right = num_experts;
+    while (left < right) {
+      int mid = (left + right) >> 1;
+      if (prefix[mid] <= block_start) {
+        left = mid + 1;
+      } else {
+        right = mid;
+      }
+    }
+    expert_ids[i] = left - 2;
+  }
+
+  if (pad_sorted_token_ids) {
+    Vec fill_vec;
+    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
+    int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
+    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
+    for (int32_t i = tid; i < total_vecs; i += stride) {
+      out_ptr[i] = fill_vec;
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void moe_align_block_size_small_batch_expert_kernel(
+    const scalar_t* __restrict__ topk_ids,
+    int32_t* __restrict__ sorted_token_ids,
+    int32_t* __restrict__ expert_ids,
+    int32_t* __restrict__ total_tokens_post_pad,
+    int32_t num_experts,
+    int32_t block_size,
+    size_t numel,
+    bool pad_sorted_token_ids) {
+  const size_t tid = threadIdx.x;
+  const size_t stride = blockDim.x;
+
+  extern __shared__ int32_t shared_mem[];
+  int32_t* cumsum = shared_mem;
+  int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
+
+  for (int i = 0; i < num_experts; ++i) {
+    tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
+  }
+
+  for (size_t i = tid; i < numel; i += stride) {
+    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1];
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[threadIdx.x] = 0;
+    for (int i = 1; i <= blockDim.x; ++i) {
+      tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + 
threadIdx.x]; + } + } + + __syncthreads(); + + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { + expert_ids[i / block_size] = threadIdx.x - 1; + } + } + + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } + } + + __syncthreads(); + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[threadIdx.x * num_experts + expert_id]; + } +} + +void moe_align_block_size( + torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + torch::Tensor cumsum_buffer, + bool pad_sorted_token_ids) { + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + int threads = 1024; + + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t threads = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + + auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel; + hipLaunchKernelGGL(( small_batch_expert_kernel), dim3(1), dim3(threads), shared_mem_size, stream, + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + pad_sorted_token_ids); + } else { + auto align_kernel = moe_align_block_size_kernel; + + const size_t scan_size = next_pow2(num_experts); + const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t); + hipLaunchKernelGGL(( align_kernel), dim3(1), dim3(threads), shared_mem_size, stream, + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + cumsum_buffer.data_ptr(), + pad_sorted_token_ids, + scan_size); + + const int block_threads = ::min(256, (int)threads); + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = ::min(num_blocks, max_blocks); + + auto sort_kernel = count_and_sort_expert_tokens_kernel; + hipLaunchKernelGGL(( sort_kernel), dim3(actual_blocks), dim3(block_threads), 0, stream, + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), + topk_ids.numel()); + } + }); +} diff --git a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.hip b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.hip new file mode 100644 index 000000000000..71e2b0349d36 --- /dev/null +++ 
b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.hip @@ -0,0 +1,593 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/moe/topk_softmax_kernels.cu +// which is originally adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#ifndef USE_ROCM +#include +#include +#include +#else +#include +#include +#endif + +#include "utils_hip.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +// Define reduction operators based on CUDA version +// CUDA 13 (12.9+) deprecated hipcub::Max/Min in favor of cuda::maximum/minimum +#if TORCH_HIP_VERSION >= 12090 +using MaxReduceOp = cuda::maximum<>; +using MinReduceOp = cuda::minimum<>; +#else +using MaxReduceOp = hipcub::Max; +using MinReduceOp = hipcub::Min; +#endif + +/// Aligned array type +template < + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = sizeof(T) * N> +class alignas(Alignment) AlignedArray { + T data[N]; +}; + +// ========================== Util functions to convert types ========================== +template +__device__ float convert_to_float(T x) { + if constexpr (std::is_same_v) { + return __half2float(x); + } else if constexpr (std::is_same_v) { + return __bfloat162float(x); + } else if constexpr (std::is_same_v) { + return x; + } else { + return static_cast(x); + } +} + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + float threadData(-FLT_MAX); + + // Don't touch finished rows. 
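+  // For the remaining rows the kernel computes a numerically stable softmax,
+  //   y_i = exp(x_i - max_j x_j) / sum_k exp(x_k - max_j x_j),
+  // using one BlockReduce for the row max and a second for the normalizer Z.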
+ if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(convert_to_float(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp()); + + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((convert_to_float(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Sum(threadData); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = exp((convert_to_float(input[idx]) - float_max)) * normalizing_factor; + output[idx] = val; + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK( + const float* inputs_after_softmax, + const bool* finished, + float* output, + int* indices, + const int num_experts, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize) { + using cub_kvp = hipcub::KeyValuePair; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + hipcub::ArgMax arg_max; + + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + float row_sum_for_renormalize = 0; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + row_sum_for_renormalize += result_kvp.value; + } + __syncthreads(); + } + + if (renormalize && threadIdx.x == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * block_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. 
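+
+  As a worked sizing example (assuming a 32-lane warp): with 64 fp16 experts
+  and 16-byte vector loads, each thread owns VPT = 8 elements, 8 threads
+  cooperate on one row, and each warp covers 4 rows per pass.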
+
+  Limitations:
+  1) This implementation is intended for when the number of experts is a small power of 2.
+  2) This implementation assumes k is small, but will work for any k.
+*/
+
+template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
+    const T* input,
+    const bool* finished,
+    float* output,
+    const int num_rows,
+    int* indices,
+    const int k,
+    const int start_expert,
+    const int end_expert,
+    const bool renormalize) {
+  // We begin by enforcing compile time assertions and setting up compile time constants.
+  static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
+  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
+  static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
+  static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
+
+  // Number of bytes each thread pulls in per load
+  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
+  static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
+  static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
+  static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
+
+  // Restrictions based on previous section.
+  static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
+  static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
+  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
+  static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
+
+  // We have NUM_EXPERTS elements per row. We specialize for small #experts.
+  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
+  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
+  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
+
+  // Restrictions for previous section.
+  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
+
+  // ===================== From this point, we finally start computing run-time variables. ========================
+
+  // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
+  // Thus, each block processes a chunk of rows. We start by computing the start row for each block.
+  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
+
+  // Now, using the base row per thread block, we compute the base row per warp.
+  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
+
+  // The threads in a warp are split into sub-groups that will work on a row.
+  // We compute the row offset for each thread sub-group.
+  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
+  const int thread_row = warp_base_row + thread_row_in_warp;
+
+  // Threads with indices out of bounds should early exit here.
+  if (thread_row >= num_rows) {
+    return;
+  }
+  const bool row_is_active = finished ? !finished[thread_row] : true;
+
+  // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
+  // row it will read.
+  const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
+
+  // Now, we compute the group each thread belongs to in order to determine the first column to start loads from.
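+  // For instance, with THREADS_PER_ROW = 8 and ELTS_PER_LDG = 8 (the 64
+  // fp16 expert case sketched above), lane 3 of a thread group starts its
+  // loads at column 3 * 8 = 24 of its row.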
+  const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
+  const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
+  const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
+
+  // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
+  // this can support all powers of 2 up to 16.
+  // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
+  // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
+  using AccessType = AlignedArray<T, VPT>;
+
+  // Finally, we pull in the data from global mem
+  T row_chunk_temp[VPT];
+  AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
+  const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
+#pragma unroll
+  for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
+    row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
+  }
+
+  float row_chunk[VPT];
+#pragma unroll
+  for (int ii = 0; ii < VPT; ++ii) {
+    row_chunk[ii] = convert_to_float(row_chunk_temp[ii]);
+  }
+
+  // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
+  // convert to float afterwards for the exp + sum reduction.
+  float thread_max = row_chunk[0];
+#pragma unroll
+  for (int ii = 1; ii < VPT; ++ii) {
+    thread_max = max(thread_max, row_chunk[ii]);
+  }
+
+// Now, we find the max within the thread group and distribute it among the threads. We use a butterfly reduce.
+#pragma unroll
+  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
+    thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW));
+  }
+
+  // From this point, thread_max in every thread of the group holds the max within the row.
+  // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread-local sum.
+  float row_sum = 0;
+#pragma unroll
+  for (int ii = 0; ii < VPT; ++ii) {
+    row_chunk[ii] = expf(row_chunk[ii] - thread_max);
+    row_sum += row_chunk[ii];
+  }
+
+// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
+#pragma unroll
+  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
+    row_sum += SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, row_sum, mask, THREADS_PER_ROW);
+  }
+
+  // From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
+  // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
+  // compute the entire softmax row. We could likely look at the maxes and only compute for the top-k values in the
+  // row. However, this kernel will likely not be a bottleneck and it seems better to more closely match torch and
+  // find the argmax after computing the softmax.
+  const float reciprocal_row_sum = 1.f / row_sum;
+
+#pragma unroll
+  for (int ii = 0; ii < VPT; ++ii) {
+    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
+  }
+
+  // Now, row_chunk contains the softmax of the row chunk. Next, we want to find the top-k elements in each row,
+  // along with their indices.
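+  // Sketch of the selection loop below: each of the k iterations performs a
+  // thread-local argmax over the VPT cached probabilities, a butterfly
+  // shuffle so the whole thread group agrees on the (max, expert) pair, and
+  // then the winning thread masks its value out so the next iteration finds
+  // the next-largest expert.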
+ int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + float row_sum_for_renormalize = 0; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, max_val, mask, THREADS_PER_ROW); + int other_expert = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + row_sum_for_renormalize += max_val; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } + + // Fuse renormalization of topk_weights into this kernel + if (renormalize && thread_group_idx == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; +#pragma unroll + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at compile time. 
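+// For example (assuming WARP_SIZE == 32): T = __half, EXPERTS = 64 and
+// BYTES_PER_LDG = 16 give ELTS_PER_LDG = 8, VPT = 8, THREADS_PER_ROW = 8
+// and ROWS_PER_WARP = 4, matching the sizing sketch above.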
+template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper( + const T* input, + const bool* finished, + float* output, + int* indices, + const int num_rows, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize, + hipStream_t stream) { + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + hipLaunchKernelGGL(( topkGatingSoftmax), dim3(num_blocks), dim3(block_dim), 0, stream, + input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize); +} + +#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream); + +template +void topkGatingSoftmaxKernelLauncher( + const T* gating_output, + float* topk_weights, + int* topk_indices, + float* softmax_workspace, + const int num_tokens, + const int num_experts, + const int topk, + const bool renormalize, + hipStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SOFTMAX(T, 1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SOFTMAX(T, 2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SOFTMAX(T, 4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SOFTMAX(T, 8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SOFTMAX(T, 16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SOFTMAX(T, 32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SOFTMAX(T, 64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SOFTMAX(T, 128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SOFTMAX(T, 256, WARPS_PER_TB); + break; + default: { + TORCH_CHECK( + softmax_workspace != nullptr, + "softmax_workspace must be provided for num_experts that are not a power of 2."); + static constexpr int TPB = 256; + hipLaunchKernelGGL(( moeSoftmax), dim3(num_tokens), dim3(TPB), 0, stream, gating_output, nullptr, softmax_workspace, num_experts); + hipLaunchKernelGGL(( moeTopK), dim3(num_tokens), dim3(TPB), 0, stream, + softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize); + } + } +} + +void topk_softmax( + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& gating_output, + const bool renormalize) // [num_tokens, num_experts] +{ + // Check data type + TORCH_CHECK( + gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half || + gating_output.scalar_type() == at::ScalarType::BFloat16, + "gating_output must be float32, float16, or bfloat16"); + + // Check dimensions + TORCH_CHECK(gating_output.dim() == 2, 
"gating_output must be 2D tensor [num_tokens, num_experts]"); + TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]"); + TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must be 2D tensor [num_tokens, topk]"); + + // Check shapes + TORCH_CHECK( + gating_output.size(0) == topk_weights.size(0), + "First dimension of topk_weights must match num_tokens in gating_output"); + TORCH_CHECK( + gating_output.size(0) == topk_indices.size(0), + "First dimension of topk_indices must match num_tokens in gating_output"); + TORCH_CHECK( + topk_weights.size(-1) == topk_indices.size(-1), + "Second dimension of topk_indices must match topk in topk_weights"); + TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts"); + + const int num_experts = static_cast(gating_output.size(-1)); + const int num_tokens = static_cast(gating_output.size(0)); + const int topk = static_cast(topk_weights.size(-1)); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + const bool needs_workspace = !is_pow_2 || num_experts > 256; + const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0; + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + torch::Tensor softmax_workspace = + torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float)); + + const at::ScalarType dtype = gating_output.scalar_type(); + if (dtype == at::ScalarType::Float) { + topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + stream); + } else if (dtype == at::ScalarType::Half) { + topkGatingSoftmaxKernelLauncher<__half>( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + stream); + } else if (dtype == at::ScalarType::BFloat16) { + topkGatingSoftmaxKernelLauncher<__hip_bfloat16>( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype); + } +} diff --git a/sgl-kernel/csrc/speculative/eagle_utils.hip b/sgl-kernel/csrc/speculative/eagle_utils.hip new file mode 100644 index 000000000000..63bd7bbbc312 --- /dev/null +++ b/sgl-kernel/csrc/speculative/eagle_utils.hip @@ -0,0 +1,410 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* + * Copyright (c) 2025 by SGLang team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#ifndef USE_ROCM +#include "pytorch_extension_utils.h" +#else +#include "pytorch_extension_utils_rocm.h" +#endif + +typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode; + +// parent_list [bs, topk * (depth - 1) + 1)] +// selected_index [bs, draft_token_num - 1] +// verified_seq_len [bs] +// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = +// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, +// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token] +__global__ void build_tree_efficient( + int64_t* parent_list, + int64_t* selected_index, + int64_t* verified_seq_len, + bool* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int64_t* retrive_next_token, + int64_t* retrive_next_sibling, + int topk, + int depth, + int draft_token_num, + int tree_mask_mode) { + int bid = blockIdx.x; + int tid = threadIdx.x; + + if (tid >= draft_token_num) { + return; + } + int seq_tree_idx = draft_token_num * draft_token_num * bid; + for (int i = 0; i < bid; i++) { + seq_tree_idx += verified_seq_len[i] * draft_token_num; + } + int seq_len = verified_seq_len[bid]; + int token_tree_idx; + if (tree_mask_mode == FULL_MASK) { + token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; + } else { + token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1; + } + tree_mask[token_tree_idx - 1] = true; + for (int i = 0; i < draft_token_num - 1; i++) { + tree_mask[token_tree_idx + i] = false; + } + + int position = 0; + if (tid == 0) { + positions[bid * draft_token_num] = seq_len; + + int retrive_index_offset = bid * draft_token_num; + for (int i = draft_token_num - 1; i > 0; --i) { + int current_token_idx = retrive_index_offset + i; + retrive_index[bid * draft_token_num + i] = current_token_idx; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk; + int parent_position = 0; + if (parent_tb_idx > 0) { + int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (; parent_position < draft_token_num; ++parent_position) { + if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) { + ++parent_position; + break; + } + } + } + if (parent_position == draft_token_num) { + printf( + "WARNING: invalid eagle tree!!! Detected a token with no parent token selected. " + "Please check if the logprob has nan. 
The token will be ignored to keep proceeding.\n"); + continue; + } + + if (retrive_next_token[bid * draft_token_num + parent_position] == -1) { + retrive_next_token[bid * draft_token_num + parent_position] = i; + } else { + int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position]; + retrive_next_token[bid * draft_token_num + parent_position] = i; + retrive_next_sibling[bid * draft_token_num + i] = origin_next_token; + } + } + retrive_index[bid * draft_token_num] = bid * draft_token_num; + } else { + int cur_position = tid - 1; + while (true) { + position += 1; + tree_mask[token_tree_idx + cur_position] = true; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk; + if (parent_tb_idx == 0) { + break; + } + + int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (cur_position = 0; cur_position < draft_token_num; ++cur_position) { + if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) { + break; + } + } + } + positions[bid * draft_token_num + tid] = position + seq_len; + } +} + +// parent_list [bs, topk * (depth - 1) + 1)] +// selected_index [bs, draft_token_num - 1] +// verified_seq_len [bs] +// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item] +// positions [bs * draft_token] +// retrive_index [bs, draft_token] +// retrive_next_token [bs, draft_token] +// retrive_next_sibling [bs, draft_token] +__global__ void build_tree_efficient_partial_packed( + int64_t* parent_list, + int64_t* selected_index, + int64_t* verified_seq_len, + uint8_t* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int64_t* retrive_next_token, + int64_t* retrive_next_sibling, + int topk, + int depth, + int draft_token_num, + size_t num_bytes_per_item) { + int bid = blockIdx.x; + int tid = threadIdx.x; + + if (tid >= draft_token_num) { + return; + } + int seq_len = verified_seq_len[bid]; + int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item; + tree_mask[token_tree_idx] = 1; // little endian + + int position = 0; + if (tid == 0) { + positions[bid * draft_token_num] = seq_len; + + int retrive_index_offset = bid * draft_token_num; + for (int i = draft_token_num - 1; i > 0; --i) { + int current_token_idx = retrive_index_offset + i; + retrive_index[bid * draft_token_num + i] = current_token_idx; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk; + int parent_position = 0; + if (parent_tb_idx > 0) { + int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (; parent_position < draft_token_num; ++parent_position) { + if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) { + ++parent_position; + break; + } + } + } + if (parent_position == draft_token_num) { + printf( + "WARNING: invalid eagle tree!!! Detected a token with no parent token selected. " + "Please check if the logprob has nan. 
The token will be ignored to keep proceeding.\n"); + continue; + } + + if (retrive_next_token[bid * draft_token_num + parent_position] == -1) { + retrive_next_token[bid * draft_token_num + parent_position] = i; + } else { + int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position]; + retrive_next_token[bid * draft_token_num + parent_position] = i; + retrive_next_sibling[bid * draft_token_num + i] = origin_next_token; + } + } + retrive_index[bid * draft_token_num] = bid * draft_token_num; + } else { + int cur_position = tid - 1; + while (true) { + position += 1; + int byte_idx = (cur_position + 1) / 8; + int bit_idx = (cur_position + 1) % 8; + tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx); + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk; + if (parent_tb_idx == 0) { + break; + } + + int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (cur_position = 0; cur_position < draft_token_num; ++cur_position) { + if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) { + break; + } + } + } + positions[bid * draft_token_num + tid] = position + seq_len; + } +} + +void build_tree_kernel_efficient( + at::Tensor parent_list, + at::Tensor selected_index, + at::Tensor verified_seq_len, + at::Tensor tree_mask, + at::Tensor positions, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + int64_t topk, + int64_t depth, + int64_t draft_token_num, + int64_t tree_mask_mode) { + // TODO (ying) check shape + // TODO (ying) check type + int bs = parent_list.size(0); + dim3 grid(bs); + dim3 block(draft_token_num); + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + if (tree_mask_mode == QLEN_ONLY_BITPACKING) { + size_t num_bytes_per_item = 1; + if (draft_token_num > 16) { + num_bytes_per_item = 4; + } else if (draft_token_num > 8) { + num_bytes_per_item = 2; + } + hipLaunchKernelGGL(( build_tree_efficient_partial_packed), dim3(grid), dim3(block), 0, stream, + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num), + num_bytes_per_item); + } else { + hipLaunchKernelGGL(( build_tree_efficient), dim3(grid), dim3(block), 0, stream, + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num), + int32_t(tree_mask_mode)); + } +} + +template +__global__ void VerifyTreeGreedy( + IdType* predicts, + IdType* accept_index, + IdType* accept_token_num, // mutable + IdType2* candidates, + IdType2* retrive_index, + IdType2* retrive_next_token, + IdType2* retrive_next_sibling, + IdType2* target_predict, + uint32_t batch_size, + uint32_t num_speculative_tokens, + uint32_t num_draft_tokens) { + uint32_t bx = blockIdx.x; + + IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens]; + accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx; + 
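+  // Sketch of the greedy walk below: from the last accepted node, follow
+  // retrive_next_token to its first child; if the draft token there differs
+  // from the target model's prediction, fall through the child's
+  // retrive_next_sibling chain instead. The walk ends at the first depth
+  // where every candidate is rejected, so each row accepts at most
+  // num_speculative_tokens - 1 draft tokens.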
+
+template <typename IdType, typename IdType2>
+__global__ void VerifyTreeGreedy(
+    IdType* predicts,
+    IdType* accept_index,
+    IdType* accept_token_num,  // mutable
+    IdType2* candidates,
+    IdType2* retrive_index,
+    IdType2* retrive_next_token,
+    IdType2* retrive_next_sibling,
+    IdType2* target_predict,
+    uint32_t batch_size,
+    uint32_t num_speculative_tokens,
+    uint32_t num_draft_tokens) {
+  uint32_t bx = blockIdx.x;
+
+  IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
+  accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
+  uint32_t num_accepted_tokens = 0;
+  IdType2 cur_index = 0;
+
+  for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
+    cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
+    while (cur_index != -1) {
+      IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
+      IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
+      IdType2 target_token_id = target_predict[last_accepted_retrive_idx];
+
+      if (draft_token_id == target_token_id) {
+        // accept token
+        predicts[last_accepted_retrive_idx] = target_token_id;
+        ++num_accepted_tokens;
+        accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
+        last_accepted_retrive_idx = draft_index;
+        break;
+      } else {
+        cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
+      }
+    }
+    if (cur_index == -1) break;
+  }
+  accept_token_num[bx] = num_accepted_tokens;
+  predicts[last_accepted_retrive_idx] = target_predict[last_accepted_retrive_idx];
+}
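VerifyTreeGreedy above walks the tree once per request: from the last accepted node it follows retrive_next_token to the first child, then retrive_next_sibling across mismatching children, accepting at most one token per speculative step. A CPU reference of the same walk for a single request (an illustrative sketch of ours, not part of this PR; arrays are the per-request rows of the kernel's inputs):

#include <cstdint>

inline int greedy_accept_count(const int64_t* candidates, const int64_t* next_token,
                               const int64_t* next_sibling, const int64_t* target,
                               int num_speculative_tokens) {
  int accepted = 0;
  int64_t last = 0;             // the root is always accepted
  int64_t cur = next_token[0];  // first child of the root
  for (int step = 1; step < num_speculative_tokens && cur != -1;) {
    if (candidates[cur] == target[last]) {
      // draft token matches the target model's prediction at the parent: accept
      last = cur;
      ++accepted;
      ++step;
      cur = next_token[cur];
    } else {
      // try the next sibling at the same depth
      cur = next_sibling[cur];
    }
  }
  return accepted;
}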
+
+// predicts: [tot_num_draft_tokens]
+// accept_index: [bs, num_spec_step]
+// accept_token_num: [bs]
+// candidates: [bs, num_draft_tokens]
+// retrive_index: [bs, num_draft_tokens]
+// retrive_next_token: [bs, num_draft_tokens]
+// retrive_next_sibling: [bs, num_draft_tokens]
+// target_predict: [bs, num_draft_tokens]
+void verify_tree_greedy(
+    at::Tensor predicts,
+    at::Tensor accept_index,
+    at::Tensor accept_token_num,  // mutable
+    at::Tensor candidates,
+    at::Tensor retrive_index,
+    at::Tensor retrive_next_token,
+    at::Tensor retrive_next_sibling,
+    at::Tensor target_predict,
+    int64_t cuda_stream = 0) {
+  CHECK_INPUT(candidates);
+  CHECK_INPUT(retrive_index);
+  CHECK_INPUT(retrive_next_token);
+  CHECK_INPUT(retrive_next_sibling);
+  CHECK_INPUT(target_predict);
+  auto device = target_predict.device();
+  CHECK_EQ(candidates.device(), device);
+  CHECK_EQ(retrive_index.device(), device);
+  CHECK_EQ(retrive_next_token.device(), device);
+  CHECK_EQ(retrive_next_sibling.device(), device);
+  CHECK_EQ(target_predict.device(), device);
+  CHECK_DIM(1, predicts);
+  CHECK_DIM(2, accept_index);
+  CHECK_DIM(1, accept_token_num);
+  CHECK_DIM(2, candidates);
+  CHECK_DIM(2, retrive_index);
+  CHECK_DIM(2, retrive_next_token);
+  CHECK_DIM(2, retrive_next_sibling);
+  CHECK_DIM(2, target_predict);
+  unsigned int batch_size = candidates.size(0);
+  unsigned int num_spec_step = accept_index.size(1);
+  unsigned int num_draft_tokens = candidates.size(1);
+  CHECK_EQ(batch_size, accept_index.size(0));
+  CHECK_EQ(batch_size, accept_token_num.size(0));
+  CHECK_EQ(batch_size, retrive_index.size(0));
+  CHECK_EQ(batch_size, retrive_next_token.size(0));
+  CHECK_EQ(batch_size, retrive_next_sibling.size(0));
+  CHECK_EQ(batch_size, target_predict.size(0));
+  CHECK_EQ(num_draft_tokens, retrive_index.size(1));
+  CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
+  CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
+  CHECK_EQ(num_draft_tokens, target_predict.size(1));
+  if (predicts.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
+  }
+  if (accept_index.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
+  }
+  if (accept_token_num.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
+  }
+  if (candidates.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
+  }
+  if (retrive_index.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
+  }
+  if (retrive_next_token.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
+  }
+  if (retrive_next_sibling.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
+  }
+  if (target_predict.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
+  }
+
+  hipStream_t stream = reinterpret_cast<hipStream_t>(cuda_stream);
+  dim3 grid(batch_size);
+  dim3 block(1);
+
+  hipLaunchKernelGGL(
+      (VerifyTreeGreedy<int32_t, int64_t>),
+      dim3(grid),
+      dim3(block),
+      0,
+      stream,
+      static_cast<int32_t*>(predicts.data_ptr()),
+      static_cast<int32_t*>(accept_index.data_ptr()),
+      static_cast<int32_t*>(accept_token_num.data_ptr()),
+      static_cast<int64_t*>(candidates.data_ptr()),
+      static_cast<int64_t*>(retrive_index.data_ptr()),
+      static_cast<int64_t*>(retrive_next_token.data_ptr()),
+      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
+      static_cast<int64_t*>(target_predict.data_ptr()),
+      batch_size,
+      num_spec_step,
+      num_draft_tokens);
+}
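verify_tree_greedy requires int32 output tensors and int64 inputs, as checked above. A sketch of a conforming call site (a fragment of ours, assuming <torch/all.h> and existing int64 device tensors candidates, retrive_index, retrive_next_token, retrive_next_sibling, and target_predict):

// bs / num_draft_tokens come from candidates; num_spec_step is the caller's choice.
auto i32 = candidates.options().dtype(at::kInt);
auto predicts = at::zeros({bs * num_draft_tokens}, i32);
auto accept_index = at::full({bs, num_spec_step}, -1, i32);
auto accept_token_num = at::zeros({bs}, i32);
verify_tree_greedy(predicts, accept_index, accept_token_num, candidates,
                   retrive_index, retrive_next_token, retrive_next_sibling,
                   target_predict);  // cuda_stream defaults to the null stream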
diff --git a/sgl-kernel/include/hip/hip_act_and_mul_hip.cuh b/sgl-kernel/include/hip/hip_act_and_mul_hip.cuh
new file mode 100644
index 000000000000..4b71d9213b20
--- /dev/null
+++ b/sgl-kernel/include/hip/hip_act_and_mul_hip.cuh
@@ -0,0 +1,89 @@
+// !!! This is a file automatically generated by hipify!!!
+#include "hip/hip_runtime.h"
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include "utils_hip.h"
+
+#define kBitsToLoad 128
+#define kBytesToLoad (kBitsToLoad / 8)
+
+// Adapted from
+// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)
+
+namespace sgl_hip {
+namespace activation {
+
+template <typename T, float (*Activation)(const float&)>
+__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
+  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
+  const int64_t token_idx = blockIdx.x;
+  const int64_t thread_idx = threadIdx.x;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * 2 * d;
+
+#pragma unroll 1
+  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
+    sgl_hip::vec_t<float, vec_size> x_vec, y_vec, out_vec;
+    x_vec.cast_load(input + offset + idx * vec_size);
+    y_vec.cast_load(input + offset + d + idx * vec_size);
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
+    }
+    out_vec.cast_store(out + token_idx * d + idx * vec_size);
+  }
+
+  const int64_t remaining_offset = d - d % (stride * vec_size);
+  // process the remaining elements
+#pragma unroll 1
+  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
+    T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx];
+    out[token_idx * d + remaining_offset + idx] = Activation(x) * y;
+  }
+}
+
+template <typename T, float (*Activation)(const float&)>
+__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
+  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
+  const int64_t token_idx = blockIdx.x;
+  const int64_t thread_idx = threadIdx.x;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * d;
+
+#pragma unroll 1
+  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
+    sgl_hip::vec_t<float, vec_size> x_vec, out_vec;
+    x_vec.cast_load(input + offset + idx * vec_size);
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      out_vec[i] = Activation(x_vec[i]);
+    }
+    out_vec.cast_store(out + token_idx * d + idx * vec_size);
+  }
+
+  const int64_t remaining_offset = d - d % (stride * vec_size);
+  // process the remaining elements
+#pragma unroll 1
+  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
+    T x = input[offset + remaining_offset + idx];
+    out[token_idx * d + remaining_offset + idx] = Activation(x);
+  }
+}
+
+}  // namespace activation
+}  // namespace sgl_hip
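Both kernels above take the activation as a compile-time function pointer, so the elementwise op fuses with the gated multiply at no dispatch cost. A sketch of an instantiation (silu_f and the launch parameters are ours, not part of this header):

__device__ __forceinline__ float silu_f(const float& x) {
  return x / (1.0f + __expf(-x));
}

// input: [num_tokens, 2 * d] laid out as (gate | up); out: [num_tokens, d]
// hipLaunchKernelGGL((sgl_hip::activation::act_and_mul_kernel<__half, silu_f>),
//                    dim3(num_tokens), dim3(256), 0, stream, out, input, d);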
diff --git a/sgl-kernel/include/utils_hip.h b/sgl-kernel/include/utils_hip.h
new file mode 100644
index 000000000000..1c7f2f254867
--- /dev/null
+++ b/sgl-kernel/include/utils_hip.h
@@ -0,0 +1,470 @@
+// !!! This is a file automatically generated by hipify!!!
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/all.h>
+#include <mutex>
+#include <sstream>
+
+#ifdef USE_ROCM
+#include <hip/hip_bf16.h>
+#endif
+
+#ifdef USE_ROCM
+// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
+#define _DISPATCH_CASE_F16(c_type, ...) \
+  case at::ScalarType::Half: {          \
+    using c_type = __half;              \
+    return __VA_ARGS__();               \
+  }
+
+#define _DISPATCH_CASE_BF16(c_type, ...) \
+  case at::ScalarType::BFloat16: {       \
+    using c_type = __hip_bfloat16;       \
+    return __VA_ARGS__();                \
+  }
+#endif  // USE_ROCM
+
+#ifndef USE_ROCM
+// Adapt from FlashInfer
+#ifdef FLASHINFER_ENABLE_F16
+#define _DISPATCH_CASE_F16(c_type, ...) \
+  case at::ScalarType::Half: {          \
+    using c_type = nv_half;             \
+    return __VA_ARGS__();               \
+  }
+#else
+#define _DISPATCH_CASE_F16(c_type, ...)
+#endif  // FLASHINFER_ENABLE_F16
+
+#ifdef FLASHINFER_ENABLE_BF16
+#define _DISPATCH_CASE_BF16(c_type, ...) \
+  case at::ScalarType::BFloat16: {       \
+    using c_type = nv_bfloat16;          \
+    return __VA_ARGS__();                \
+  }
+#else
+#define _DISPATCH_CASE_BF16(c_type, ...)
+#endif  // FLASHINFER_ENABLE_BF16
+
+#ifdef FLASHINFER_ENABLE_FP8_E4M3
+#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
+  case at::ScalarType::Float8_e4m3fn: {      \
+    using c_type = __nv_fp8_e4m3;            \
+    return __VA_ARGS__();                    \
+  }
+#else
+#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
+#endif  // FLASHINFER_ENABLE_FP8_E4M3
+
+#ifdef FLASHINFER_ENABLE_FP8_E5M2
+#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
+  case at::ScalarType::Float8_e5m2: {        \
+    using c_type = __nv_fp8_e5m2;            \
+    return __VA_ARGS__();                    \
+  }
+#else
+#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
+#endif  // FLASHINFER_ENABLE_FP8_E5M2
+
+#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...)                 \
+  [&]() -> bool {                                                                        \
+    switch (pytorch_dtype) {                                                             \
+      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
+      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
+      default:                                                                           \
+        std::ostringstream oss;                                                          \
+        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
+        TORCH_CHECK(false, oss.str());                                                   \
+        return false;                                                                    \
+    }                                                                                    \
+  }()
+
+#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...)                      \
+  [&]() -> bool {                                                                            \
+    switch (pytorch_dtype) {                                                                 \
+      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                           \
+      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                           \
+      default:                                                                               \
+        std::ostringstream oss;                                                              \
+        oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \
+        TORCH_CHECK(false, oss.str());                                                       \
+        return false;                                                                        \
+    }                                                                                        \
+  }()
+
+#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...)                      \
+  [&]() -> bool {                                                                        \
+    switch (pytorch_dtype) {                                                             \
+      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
+      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
+      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                       \
+      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                       \
+      default:                                                                           \
+        std::ostringstream oss;                                                          \
+        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
+        TORCH_CHECK(false, oss.str());                                                   \
+        return false;                                                                    \
+    }                                                                                    \
+  }()
+
+#define _DISPATCH_SWITCH(var_name, cond, ...)                                        \
+  [&]() -> bool {                                                                    \
+    switch (cond) {                                                                  \
+      __VA_ARGS__                                                                    \
+      default:                                                                       \
+        std::ostringstream oss;                                                      \
+        oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \
+        TORCH_CHECK(false, oss.str());                                               \
+        return false;                                                                \
+    }                                                                                \
+  }()
+
+#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...)
\ + [&]() -> bool { \ + switch (pack_u16(cond1, cond2)) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \ + << int(cond2) << ")"; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \ + case pack_u16(case_expr1, case_expr2): { \ + constexpr auto case_var1 = case_expr1; \ + constexpr auto case_var2 = case_expr2; \ + return __VA_ARGS__(); \ + } + +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + +inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \ + TORCH_CHECK( \ + num_qo_heads % num_kv_heads == 0, \ + "num_qo_heads(", \ + num_qo_heads, \ + ") must be divisible by num_kv_heads(", \ + num_kv_heads, \ + ")") + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) + +inline bool is_float8_tensor(const at::Tensor& tensor) { + return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2; +} +#endif // USE_ROCM + +struct cuda_error : public std::runtime_error { + /** + * @brief Constructs a `cuda_error` object with the given `message`. + * + * @param message The error char array used to construct `cuda_error` + */ + cuda_error(const char* message) : std::runtime_error(message) {} + /** + * @brief Constructs a `cuda_error` object with the given `message` string. 
+ * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + hipError_t e = cmd; \ + if (e != hipSuccess) { \ + std::stringstream _message; \ + auto s = hipGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) + +inline int getSMVersion() { + int device{-1}; + CHECK_CUDA_SUCCESS(hipGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_major, hipDeviceAttributeComputeCapabilityMajor, device)); + CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_minor, hipDeviceAttributeComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline bool isDeviceType(const std::string& device_type) { + int deviceCount; + CHECK_CUDA_SUCCESS(hipGetDeviceCount(&deviceCount)); + + int device_id = -1; + if (deviceCount >= 1) { + CHECK_CUDA_SUCCESS(hipGetDevice(&device_id)); + } else { + return false; + } + + hipDeviceProp_t prop; + CHECK_CUDA_SUCCESS(hipGetDeviceProperties(&prop, device_id)); + if (device_type == std::string(prop.name)) { + return true; + } + return false; +} + +inline bool getBoolEnv(char const* name) { + char const* env = std::getenv(name); + return env && env[0] == '1' && env[1] == '\0'; +} + +inline bool getEnvEnablePDL() { + static std::once_flag flag; + static bool enablePDL = false; + std::call_once(flag, [&]() { + if (getSMVersion() >= 90) { + // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` + enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); + } + }); + return enablePDL; +} + +// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28 +#ifndef USE_ROCM +#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask)) +#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width)) +#else +#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask)) +#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width)) +#endif + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define DISPATCH_CASE_FLOAT_TYPES(...) 
\
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOAT_TYPES(__VA_ARGS__))
+
+#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
+#endif
+
+#ifdef USE_ROCM
+
+#include "hip/hip_math_def.h"
+#include "hip/hip_vec_dtypes.h"
+
+#else
+
+template <typename srcDtype>
+__device__ __forceinline__ float castToFloat(srcDtype val) {
+  return static_cast<float>(val);
+}
+
+template <typename dstDtype>
+__device__ __forceinline__ dstDtype castFromFloat(float val) {
+  return static_cast<dstDtype>(val);
+}
+
+#endif
+
+// add FP8 support
+#ifndef USE_ROCM
+#include <c10/util/Float8_e4m3fn.h>
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+#else  // USE_ROCM
+#if HIP_FP8_TYPE_FNUZ
+#include <c10/util/Float8_e4m3fnuz.h>
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#else
+#if HIP_FP8_TYPE_E4M3
+#include <c10/util/Float8_e4m3fn.h>
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+#else
+#error "fp8 is not supported in this processor (arch < gfx942)."
+#endif  // HIP_FP8_TYPE_E4M3
+#endif  // HIP_FP8_TYPE_FNUZ
+#endif  // USE_ROCM
+
+#define FULL_MASK 0xffffffff
+
+__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
+#ifndef USE_ROCM
+  float old;
+  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
+                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
+  return old;
+#else
+  int* addr_as_i = (int*)addr;
+  int old = *addr_as_i, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed))));
+  } while (assumed != old);
+  return __int_as_float(old);
+#endif
+}
+
+__device__ __forceinline__ float warpReduceMax(float value) {
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2));
+  value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1));
+  return value;
+}
+
+__device__ __forceinline__ float blockReduceMax(float value) {
+  static __shared__ float warpLevelMaxs[WARP_SIZE];
+  const int laneId = threadIdx.x % WARP_SIZE;
+  const int warpId = threadIdx.x / WARP_SIZE;
+
+  value = warpReduceMax(value);
+
+  if (laneId == 0) warpLevelMaxs[warpId] = value;
+  __syncthreads();
+
+  value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
+  if (warpId == 0) value = warpReduceMax(value);
+
+  return value;
+}
+
+// Pads to a multiple of `alignment` rows.
+inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) { + int64_t rows = tensor.size(0); + int64_t cols = tensor.size(1); + int64_t pad_rows = (alignment - (rows % alignment)) % alignment; // Compute padding size + + if (pad_rows == 0) { + return tensor; // Already aligned + } + + torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options()); + torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); // Pad along rows + + // Ensure column-major layout + if (is_column_major) { + return tensor_padded.t().contiguous().t(); + } + return tensor_padded; +} + +// Get the next power of 2 of a number +inline uint32_t next_pow2(uint32_t x) noexcept { + if (x <= 1) return 1; + return 1u << (32 - __builtin_clz(x - 1)); +} + +/* + * LDG Support + */ +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 5a1637acf8a3..a0298a587bc8 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -1,10 +1,11 @@ [build-system] requires = [ + "setuptools>=75.0", "scikit-build-core>=0.10", "torch>=2.8.0", "wheel", ] -build-backend = "scikit_build_core.build" +build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" @@ -29,11 +30,3 @@ exclude = [ "dist*", "tests*", ] - -[tool.scikit-build] -cmake.build-type = "Release" -minimum-version = "build-system.requires" - -wheel.py-api = "cp310" -wheel.license-files = [] -wheel.packages = ["python/sgl_kernel"] diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml deleted file mode 100644 index a0298a587bc8..000000000000 --- a/sgl-kernel/pyproject_rocm.toml +++ /dev/null @@ -1,32 +0,0 @@ -[build-system] -requires = [ - "setuptools>=75.0", - "scikit-build-core>=0.10", - "torch>=2.8.0", - "wheel", -] -build-backend = "setuptools.build_meta" - -[project] -name = "sgl-kernel" -version = "0.3.16.post2" -description = "Kernel Library for SGLang" -readme = "README.md" -requires-python = ">=3.10" -license = { file = "LICENSE" } -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", - "Environment :: GPU :: NVIDIA CUDA" -] -dependencies = [] - -[project.urls] -"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel" -"Bug Tracker" = "https://github.com/sgl-project/sglang/issues" - -[tool.wheel] -exclude = [ - "dist*", - "tests*", -]